<!DOCTYPE html>
<html lang="ja">
<head>
    <meta http-equiv="content-type" content="text/html;charset=utf-8"/>
    <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
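    <meta name="description" content="This is an annotated PyTorch implementation of Sketch RNN from the paper A Neural Representation of Sketch Drawings. Sketch RNN is a sequence-to-sequence model that generates sketches of objects such as bicycles and cats."/>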

    <meta name="twitter:card" content="summary"/>
    <meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
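    <meta name="twitter:title" content="Sketch RNN"/>
    <meta name="twitter:description" content="This is an annotated PyTorch implementation of Sketch RNN from the paper A Neural Representation of Sketch Drawings. Sketch RNN is a sequence-to-sequence model that generates sketches of objects such as bicycles and cats."/>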
    <meta name="twitter:site" content="@labmlai"/>
    <meta name="twitter:creator" content="@labmlai"/>

    <meta property="og:url" content="https://nn.labml.ai/sketch_rnn/index.html"/>
    <meta property="og:title" content="Sketch RNN"/>
    <meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta property="og:site_name" content="Sketch RNN"/>
    <meta property="og:type" content="object"/>
    <meta property="og:title" content="Sketch RNN"/>
    <meta property="og:description" content="This is an annotated PyTorch implementation of Sketch RNN from the paper A Neural Representation of Sketch Drawings. Sketch RNN is a sequence-to-sequence model that generates sketches of objects such as bicycles and cats."/>

    <title>Sketch RNN</title>
    <link rel="shortcut icon" href="/icon.png"/>
    <link rel="stylesheet" href="../pylit.css?v=1">
    <link rel="canonical" href="https://nn.labml.ai/sketch_rnn/index.html"/>
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">

    <!-- Global site tag (gtag.js) - Google Analytics -->
    <script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
    <script>
        window.dataLayer = window.dataLayer || [];

        function gtag() {
            dataLayer.push(arguments);
        }

        gtag('js', new Date());

        gtag('config', 'G-4V3HC8HBLH');
    </script>
</head>
<body>
<div id='container'>
    <div id="background"></div>
    <div class='section'>
        <div class='docs'>
            <p>
                <a class="parent" href="/">home</a>
                <a class="parent" href="index.html">sketch_rnn</a>
            </p>
            <p>
                <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
                    <img alt="Github"
                         src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
                         style="max-width:100%;"/></a>
                <a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
                    <img alt="Twitter"
                         src="https://img.shields.io/twitter/follow/labmlai?style=social"
                         style="max-width:100%;"/></a>
            </p>
            <p>
                <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/sketch_rnn/__init__.py" target="_blank">
                    View code on Github</a>
            </p>
        </div>
    </div>
    <div class='section' id='section-0'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-0'>#</a>
            </div>
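            <h1>Sketch RNN</h1>
<p>This is an annotated <a href="https://pytorch.org">PyTorch</a> implementation of the paper <a href="https://papers.labml.ai/paper/1704.03477">A Neural Representation of Sketch Drawings</a>.</p>
<p>Sketch RNN is a sequence-to-sequence variational auto-encoder. Both the encoder and the decoder are recurrent neural network models. It learns to reconstruct simple stroke-based drawings by predicting a series of strokes. The decoder predicts each stroke as a mixture of Gaussians.</p>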
<h3>Getting data</h3>
<p>Download the data from the <a href="https://github.com/googlecreativelab/quickdraw-dataset">Quick, Draw! Dataset</a>. The <em>Sketch-RNN QuickDraw Dataset</em> section of the readme has a link for downloading the <code  class="highlight"><span></span><span class="n">npz</span></code>
 files. Place the downloaded <code  class="highlight"><span></span><span class="n">npz</span></code>
 file(s) in the <code  class="highlight"><span></span><span class="n">data</span><span class="o">/</span><span class="n">sketch</span></code>
 folder. This code is configured to use the <code  class="highlight"><span></span><span class="n">bicycle</span></code>
 dataset. You can change this in the configurations.</p>
<h3>Acknowledgements</h3>
<p>This implementation took help from the <a href="https://github.com/alexis-jacq/Pytorch-Sketch-RNN">PyTorch Sketch RNN</a> project by <a href="https://github.com/alexis-jacq">Alexis David Jacq</a>.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">32</span><span></span><span class="kn">import</span> <span class="nn">math</span>
<span class="lineno">33</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Any</span>
<span class="lineno">34</span>
<span class="lineno">35</span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="lineno">36</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">37</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">38</span><span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="lineno">39</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">optim</span>
<span class="lineno">40</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span><span class="p">,</span> <span class="n">DataLoader</span>
<span class="lineno">41</span>
<span class="lineno">42</span><span class="kn">import</span> <span class="nn">einops</span>
<span class="lineno">43</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">monit</span>
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">labml_helpers.device</span> <span class="kn">import</span> <span class="n">DeviceConfigs</span>
<span class="lineno">45</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">46</span><span class="kn">from</span> <span class="nn">labml_helpers.optimizer</span> <span class="kn">import</span> <span class="n">OptimizerConfigs</span>
<span class="lineno">47</span><span class="kn">from</span> <span class="nn">labml_helpers.train_valid</span> <span class="kn">import</span> <span class="n">TrainValidConfigs</span><span class="p">,</span> <span class="n">hook_model_outputs</span><span class="p">,</span> <span class="n">BatchIndex</span></pre></div>
        </div>
    </div>
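    <div class='section'>
        <div class='docs'>
            <p>A minimal sketch of reading one of the downloaded files with NumPy. It assumes the archive exposes <code>train</code> and <code>valid</code> arrays and that the arrays are pickled objects, which is why <code>allow_pickle=True</code> and <code>encoding='latin1'</code> are passed. The helper names <code>load_strokes</code> and <code>fake_split</code> are hypothetical, and a tiny stand-in file is written first so the example runs without the real download.</p>
        </div>
        <div class='code'>

```python
import tempfile
from pathlib import Path

import numpy as np


def load_strokes(path):
    """Load the training and validation splits from a Sketch-RNN style npz file."""
    # Assumption: the drawings are stored as pickled object arrays
    archive = np.load(str(path), encoding='latin1', allow_pickle=True)
    return archive['train'], archive['valid']


rng = np.random.default_rng(0)


def fake_split(n):
    """Create `n` random drawings, each an integer array of shape [seq_len, 3]."""
    out = np.empty(n, dtype=object)
    for i in range(n):
        seq_len = int(rng.integers(12, 20))
        xy = rng.integers(-5, 6, size=(seq_len, 2))
        pen = rng.integers(0, 2, size=(seq_len, 1))
        out[i] = np.concatenate([xy, pen], axis=1).astype(np.int16)
    return out


# Write a stand-in for `data/sketch/bicycle.npz` into a temporary folder
demo_path = Path(tempfile.mkdtemp()) / 'bicycle.npz'
np.savez(str(demo_path), train=fake_split(4), valid=fake_split(2))

train, valid = load_strokes(demo_path)
print(len(train), len(valid), train[0].shape)
```

        </div>
    </div>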
    <div class='section' id='section-1'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-1'>#</a>
            </div>
            <h2>Dataset</h2>
<p>This class loads and pre-processes the data.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">50</span><span class="k">class</span> <span class="nc">StrokesDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-2'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-2'>#</a>
            </div>
            <p><code  class="highlight"><span></span><span class="n">dataset</span></code>
 is a list of NumPy arrays of shape [seq_len, 3]. Each array is a sequence of strokes, and each stroke is represented by 3 integers. The first two are the displacements along x and y (<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord" style="">Δ</span><span class="mord mathnormal" style="">x</span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbn" style=""><span class="mord" style="">Δ</span><span class="mord mathnormal" style="margin-right:0.03588em">y</span></span></span></span></span></span>), and the last integer is the pen state: <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span></span></span></span></span> if the pen is lifted from the paper after this step and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span></span></span></span></span> while it stays on the paper. This matches the code in this class, which computes the pen-down flag as one minus this value.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">57</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">max_seq_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
        </div>
    </div>
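    <div class='section'>
        <div class='docs'>
            <p>To make the displacement layout concrete, here is a small sketch that turns a hypothetical [seq_len, 3] drawing into absolute pen positions. The sample values are invented for illustration; the only technique shown is that positions are the running sum of the displacements.</p>
        </div>
        <div class='code'>

```python
import numpy as np

# A hypothetical drawing: each row is (dx, dy, pen state)
seq = np.array([
    [5, 0, 0],
    [0, 5, 0],
    [-5, 0, 1],
    [10, 10, 0],
    [3, 0, 1],
], dtype=np.int16)

# Absolute positions are the cumulative sum of the displacements,
# starting from an origin at (0, 0)
xy = np.cumsum(seq[:, :2], axis=0)

# Bounding box of the drawing, useful when plotting
x_min, y_min = xy.min(axis=0)
x_max, y_max = xy.max(axis=0)

print(xy.tolist())
print((int(x_min), int(y_min)), (int(x_max), int(y_max)))
```

        </div>
    </div>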
    <div class='section' id='section-3'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-3'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">67</span>        <span class="n">data</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-4'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-4'>#</a>
            </div>
            <p>We iterate through each of the sequences and filter them.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">69</span>        <span class="k">for</span> <span class="n">seq</span> <span class="ow">in</span> <span class="n">dataset</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-5'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-5'>#</a>
            </div>
            <p>Keep the sequence only if its number of strokes is within our range: more than 10 and at most the maximum sequence length.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">71</span>            <span class="k">if</span> <span class="mi">10</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="n">max_seq_length</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-6'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-6'>#</a>
            </div>
            <p>Clamp <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord" style="">Δ</span><span class="mord mathnormal" style="">x</span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbn" style=""><span class="mord" style="">Δ</span><span class="mord mathnormal" style="margin-right:0.03588em">y</span></span></span></span></span></span> to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">−</span><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span><span class="mclose">]</span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">73</span>                <span class="n">seq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>
<span class="lineno">74</span>                <span class="n">seq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="o">-</span><span class="mi">1000</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-7'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-7'>#</a>
            </div>
            <p>Convert to a floating point array and add it to <code  class="highlight"><span></span><span class="n">data</span></code>
</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">76</span>                <span class="n">seq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">77</span>                <span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span></pre></div>
        </div>
    </div>
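    <div class='section'>
        <div class='docs'>
            <p>The filtering, clamping and conversion steps above can be sketched as one standalone function. The name <code>filter_and_clamp</code> and its keyword arguments are hypothetical. <code>np.clip</code> is used here in place of the <code>np.minimum</code> and <code>np.maximum</code> pair; it gives the same result.</p>
        </div>
        <div class='code'>

```python
import numpy as np


def filter_and_clamp(dataset, max_seq_length, min_seq_length=10, limit=1000):
    """Keep sequences of acceptable length, clamp them, and convert to float32."""
    data = []
    for seq in dataset:
        # Keep sequences longer than the minimum and no longer than the maximum
        if not (max_seq_length >= len(seq) > min_seq_length):
            continue
        # Clamp every value to [-limit, limit]
        seq = np.clip(seq, -limit, limit)
        data.append(np.asarray(seq, dtype=np.float32))
    return data


too_short = np.zeros((5, 3), dtype=np.int16)
in_range = np.array([[2000, -3000, 0]] * 12, dtype=np.int16)
too_long = np.zeros((300, 3), dtype=np.int16)

kept = filter_and_clamp([too_short, in_range, too_long], max_seq_length=200)
print(len(kept), kept[0].dtype, kept[0].min(), kept[0].max())
```

        </div>
    </div>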
    <div class='section' id='section-8'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-8'>#</a>
            </div>
            <p>We then calculate the scaling factor, which is the standard deviation of (<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord" style="">Δ</span><span class="mord mathnormal" style="">x</span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbn" style=""><span class="mord" style="">Δ</span><span class="mord mathnormal" style="margin-right:0.03588em">y</span></span></span></span></span></span>) combined. The paper notes that the mean is not adjusted, for simplicity, since it is close to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span></span></span></span></span> anyway.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">83</span>        <span class="k">if</span> <span class="n">scale</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">84</span>            <span class="n">scale</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">np</span><span class="o">.</span><span class="n">ravel</span><span class="p">(</span><span class="n">s</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">])</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">data</span><span class="p">]))</span>
<span class="lineno">85</span>        <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">scale</span></pre></div>
        </div>
    </div>
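    <div class='section'>
        <div class='docs'>
            <p>A small worked example of the scaling factor, using invented values chosen so the result is easy to check by hand. The function name <code>stroke_scale</code> is hypothetical. All x and y displacements are pooled into one flat array before taking the standard deviation, and dividing by that value leaves the pooled displacements with unit standard deviation.</p>
        </div>
        <div class='code'>

```python
import numpy as np


def stroke_scale(data):
    """Standard deviation of all dx and dy values pooled together."""
    offsets = np.concatenate([np.ravel(s[:, 0:2]) for s in data])
    return float(np.std(offsets))


# Two hypothetical drawings whose displacements are all +3 or -3
data = [
    np.array([[3, -3, 0], [-3, 3, 1]], dtype=np.float32),
    np.array([[3, 3, 0], [-3, -3, 1]], dtype=np.float32),
]

scale = stroke_scale(data)

# Dividing by the scale normalizes the displacements; the pen column is untouched
normalized = [s[:, :2] / scale for s in data]
pooled_std = float(np.std(np.concatenate([n.ravel() for n in normalized])))
print(scale, pooled_std)
```

        </div>
    </div>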
    <div class='section' id='section-9'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-9'>#</a>
            </div>
            <p>Get the longest sequence length among all sequences</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">88</span>        <span class="n">longest_seq_len</span> <span class="o">=</span> <span class="nb">max</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span> <span class="k">for</span> <span class="n">seq</span> <span class="ow">in</span> <span class="n">data</span><span class="p">])</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-10'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-10'>#</a>
            </div>
            <p>We initialize the PyTorch data array with two extra steps, for start-of-sequence (sos) and end-of-sequence (eos). Each step is a vector <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqt" style=""><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord coloredeq eqba" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal coloredeq eqba" style="margin-right:0.03588em">y</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqci" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqbg" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqbg" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqcf" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span>.
Only one of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbg" style=""><span class="mord" style=""><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqci" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqcf" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span></span></span></span></span> and the others are <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span></span></span></span></span>. They represent <em>pen down</em>, <em>pen up</em> and <em>end-of-sequence</em>, in that order.
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcd" style=""><span class="mord" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqci" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span></span></span></span></span> if the pen touches the paper in the next step.
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqce" style=""><span class="mord" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span></span></span></span></span> if the pen doesn't touch the paper in the next step.
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span></span></span></span></span> if it is the end of the drawing.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">98</span>        <span class="bp">self</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="n">longest_seq_len</span> <span class="o">+</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span></pre></div>
        </div>
    </div>
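    <div class='section'>
        <div class='docs'>
            <p>The five-element step layout can be illustrated with a NumPy sketch for a single drawing. The function name <code>to_stroke5</code> is hypothetical. The pen-down flag follows the <code>1 - seq[:, 2]</code> rule used in this implementation; the pen-up flag, the end-of-sequence fill and the start-of-sequence vector are assumptions made for illustration, chosen so that exactly one of the three flags is set in every row.</p>
        </div>
        <div class='code'>

```python
import numpy as np


def to_stroke5(seq, longest_seq_len):
    """Pad a [seq_len, 3] drawing into a [longest_seq_len + 2, 5] array."""
    out = np.zeros((longest_seq_len + 2, 5), dtype=np.float32)
    n = len(seq)
    # Row 0 is start-of-sequence; the real steps occupy rows 1..n
    out[1:n + 1, :2] = seq[:, :2]
    # p1 (pen down), same rule as the dataset code
    out[1:n + 1, 2] = 1 - seq[:, 2]
    # p2 (pen up), assumed to be the complement of p1
    out[1:n + 1, 3] = seq[:, 2]
    # p3 (end-of-sequence), assumed to fill every row after the drawing
    out[n + 1:, 4] = 1
    # Assumed start-of-sequence vector (0, 0, 1, 0, 0)
    out[0, 2] = 1
    return out


seq = np.array([[1, 2, 0], [3, 4, 1], [5, 6, 0]], dtype=np.float32)
steps = to_stroke5(seq, longest_seq_len=4)
print(steps.shape)
print(steps[:, 2:].tolist())
```

        </div>
    </div>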
    <div class='section' id='section-11'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-11'>#</a>
            </div>
            <p>The mask array needs only one extra step, because it covers the outputs of the decoder, which takes in <code  class="highlight"><span></span><span class="n">data</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></code>
 and predicts the next step.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">101</span>        <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="n">longest_seq_len</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">102</span>
<span class="lineno">103</span>        <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">seq</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
<span class="lineno">104</span>            <span class="n">seq</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span>
<span class="lineno">105</span>            <span class="n">len_seq</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span></pre></div>
        </div>
    </div>
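    <div class='section'>
        <div class='docs'>
            <p>A short sketch of why the mask is one step shorter than the data. Step indices stand in for the real five-element vectors. The decoder reads every step except the last and is scored against every step except the first, so there is one mask entry per prediction. The fill rule shown, covering the real steps plus the first end-of-sequence step, is an assumption for illustration.</p>
        </div>
        <div class='code'>

```python
import numpy as np

longest_seq_len = 4  # longest drawing in the dataset
len_seq = 3          # length of this particular drawing

# One index per row of the data array: sos, the real steps, then eos and padding
steps = np.arange(longest_seq_len + 2)

decoder_inputs = steps[:-1]   # what the decoder reads
decoder_targets = steps[1:]   # what it should predict, one step ahead

# One mask entry per decoder output
mask = np.zeros(longest_seq_len + 1, dtype=np.float32)
# Assumed fill rule: count the real steps and the first eos, ignore padding
mask[:len_seq + 1] = 1

print(decoder_inputs.tolist(), decoder_targets.tolist(), mask.tolist())
```

        </div>
    </div>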
    <div class='section' id='section-12'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-12'>#</a>
            </div>
            <p>Scale and set <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqba" style=""><span class="mord" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal" style="margin-right:0.03588em">y</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">107</span>            <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="n">scale</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-13'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-13'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcd" style=""><span class="mord" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqci" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">109</span>            <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-14'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-14'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqce" style=""><span class="mord" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">111</span>            <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-15'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-15'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">113</span>            <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-16'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-16'>#</a>
            </div>
            <p>The mask is on up to the end of the sequence</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">115</span>            <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:</span><span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-17'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-17'>#</a>
            </div>
            <p>The start-of-sequence step is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqbd" style=""><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord coloredeq eqch" style="">0</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqch" style="">0</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqci" style="">1</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqch" style="">0</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqch" style="">0</span></span><span class="mclose" style="">)</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">118</span>        <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-18'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-18'>#</a>
            </div>
            <p>Size of the dataset</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">120</span>    <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-19'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-19'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">122</span>        <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-20'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-20'>#</a>
            </div>
            <p>Get a sample</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">124</span>    <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-21'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-21'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">126</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">idx</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span></pre></div>
        </div>
    </div>
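    <div class='section'>
        <div class='docs'>
            <p>The dataset construction above can be sketched on a single toy sequence. This is a minimal illustration rather than the class itself: it uses plain Python lists instead of tensors so it runs without dependencies, the helper name <code>to_stroke5</code> is hypothetical, and it assumes the padded data has <code>longest_seq_len + 2</code> rows (one for the start-of-sequence step and at least one for end-of-sequence), which is not shown in this excerpt. Each input row is taken to be (Δx, Δy, pen-lifted flag), matching how the columns are indexed above.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
def to_stroke5(seq, longest_seq_len, scale):
    """Convert one sequence of (dx, dy, pen_lifted) rows to padded 5-column rows plus a mask."""
    len_seq = len(seq)
    # Assumed layout: row 0 is start-of-sequence, rows 1..len_seq hold the strokes,
    # and every remaining row is end-of-sequence padding.
    data = [[0.0] * 5 for _ in range(longest_seq_len + 2)]
    mask = [0.0] * (longest_seq_len + 1)

    for t, (dx, dy, lifted) in enumerate(seq, start=1):
        data[t][0] = dx / scale   # scaled delta x
        data[t][1] = dy / scale   # scaled delta y
        data[t][2] = 1 - lifted   # p1: pen stays on the paper
        data[t][3] = lifted       # p2: pen is lifted after this point

    for t in range(len_seq + 1, longest_seq_len + 2):
        data[t][4] = 1.0          # p3: the drawing has ended

    for t in range(len_seq + 1):
        mask[t] = 1.0             # mask is on up to the end of the sequence

    data[0][2] = 1.0              # start-of-sequence is (0, 0, 1, 0, 0)
    return data, mask


# Two strokes, padded to a longest sequence length of 3
data, mask = to_stroke5([(2.0, 4.0, 0), (6.0, -2.0, 1)], longest_seq_len=3, scale=2.0)
for row in data:
    print(row)
print(mask)
```
</pre></div>
        </div>
    </div>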
    <div class='section' id='section-22'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-22'>#</a>
            </div>
            <h2>Bi-variate Gaussian mixture</h2>
<p>The mixture is represented by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathcal" style="margin-right:0.14736em">N</span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="">ρ</span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span>. This class adjusts the temperature and creates the categorical and Gaussian distributions from the parameters.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">129</span><span class="k">class</span> <span class="nc">BivariateGaussianMixture</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-23'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-23'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">139</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pi_logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mu_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mu_y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">140</span>                 <span class="n">sigma_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sigma_y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">rho_xy</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">141</span>        <span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span> <span class="o">=</span> <span class="n">pi_logits</span>
<span class="lineno">142</span>        <span class="bp">self</span><span class="o">.</span><span class="n">mu_x</span> <span class="o">=</span> <span class="n">mu_x</span>
<span class="lineno">143</span>        <span class="bp">self</span><span class="o">.</span><span class="n">mu_y</span> <span class="o">=</span> <span class="n">mu_y</span>
<span class="lineno">144</span>        <span class="bp">self</span><span class="o">.</span><span class="n">sigma_x</span> <span class="o">=</span> <span class="n">sigma_x</span>
<span class="lineno">145</span>        <span class="bp">self</span><span class="o">.</span><span class="n">sigma_y</span> <span class="o">=</span> <span class="n">sigma_y</span>
<span class="lineno">146</span>        <span class="bp">self</span><span class="o">.</span><span class="n">rho_xy</span> <span class="o">=</span> <span class="n">rho_xy</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-24'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-24'>#</a>
            </div>
            <p>Number of distributions in the mixture, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcj" style=""><span class="mord mathnormal" style="margin-right:0.10903em">M</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">148</span>    <span class="nd">@property</span>
<span class="lineno">149</span>    <span class="k">def</span> <span class="nf">n_distributions</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-25'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-25'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">151</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-26'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-26'>#</a>
            </div>
            <p>Adjust by temperature <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqbx" style=""><span class="mord mathnormal" style="margin-right:0.1132em">τ</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">153</span>    <span class="k">def</span> <span class="nf">set_temperature</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-27'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-27'>#</a>
            </div>
            <p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.0967699999999998em;vertical-align:-0.15em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.30977em;vertical-align:-0.686em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.62377em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqbx" style=""><span 
class="mord mathnormal" style="margin-right:0.1132em">τ</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">158</span>        <span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span> <span class="o">/=</span> <span class="n">temperature</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-28'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-28'>#</a>
            </div>
            <p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.1111079999999998em;vertical-align:-0.247em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8641079999999999em;"><span style="top:-2.4530000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">x</span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.1111079999999998em;vertical-align:-0.247em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8641079999999999em;"><span style="top:-2.4530000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">x</span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mord coloredeq eqbx" style=""><span class="mord mathnormal" style="margin-right:0.1132em">τ</span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">160</span>        <span class="bp">self</span><span class="o">.</span><span class="n">sigma_x</span> <span class="o">*=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">temperature</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-29'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-29'>#</a>
            </div>
            <p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.2472159999999999em;vertical-align:-0.383108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.864108em;"><span style="top:-2.4530000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span></span></span><span style="top:-3.1130000000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.383108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.2472159999999999em;vertical-align:-0.383108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.864108em;"><span style="top:-2.4530000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span></span></span><span style="top:-3.1130000000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 
mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.383108em;"><span></span></span></span></span></span></span><span class="mord coloredeq eqbx" style=""><span class="mord mathnormal" style="margin-right:0.1132em">τ</span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">162</span>        <span class="bp">self</span><span class="o">.</span><span class="n">sigma_y</span> <span class="o">*=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">temperature</span><span class="p">)</span></pre></div>
        </div>
    </div>
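    <div class='section'>
        <div class='docs'>
            <p>The effect of the temperature adjustment above can be checked with a small standalone sketch. It is not part of the class: the helpers <code>softmax</code> and <code>adjust</code> are hypothetical, and plain Python floats stand in for tensors. Dividing the mixture logits by a temperature below 1 sharpens the categorical distribution, and because the variance is multiplied by the temperature, the standard deviation is multiplied by its square root, which is why the code uses <code>math.sqrt</code>.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>
```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def adjust(pi_logits, sigma, temperature):
    """Apply the same two updates as the temperature step: logits / tau and sigma * sqrt(tau)."""
    scaled_logits = [v / temperature for v in pi_logits]
    scaled_sigma = sigma * math.sqrt(temperature)
    return scaled_logits, scaled_sigma


pi_logits = [2.0, 1.0, 0.0]
for tau in (1.0, 0.25):
    logits, sigma = adjust(pi_logits, 1.0, tau)
    # A lower temperature concentrates probability on the largest logit
    # and shrinks the spread of each Gaussian component.
    print(tau, [round(p, 3) for p in softmax(logits)], sigma)
```
</pre></div>
        </div>
    </div>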
    <div class='section' id='section-30'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-30'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">164</span>    <span class="k">def</span> <span class="nf">get_distribution</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-31'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-31'>#</a>
            </div>
            <p>Clamp <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbo" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.716668em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.716668em;vertical-align:-0.286108em;"></span><span 
class="mord coloredeq eqbl" style=""><span class="mord" style=""><span class="mord mathnormal" style="">ρ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> to avoid getting <code  class="highlight"><span></span><span class="n">NaN</span></code>s</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">166</span>        <span class="n">sigma_x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp_min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sigma_x</span><span class="p">,</span> <span class="mf">1e-5</span><span class="p">)</span>
<span class="lineno">167</span>        <span class="n">sigma_y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp_min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sigma_y</span><span class="p">,</span> <span class="mf">1e-5</span><span class="p">)</span>
<span class="lineno">168</span>        <span class="n">rho_xy</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rho_xy</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span> <span class="o">+</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="mf">1e-5</span><span class="p">)</span></pre></div>
        </div>
    </div>
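    <div class='section'>
        <div class='docs'>
<p>A minimal sketch of why the clamping matters, written in plain Python rather than with tensors. The determinant of the bivariate covariance matrix is σ_x² σ_y² (1 − ρ_xy²), which reaches zero when a standard deviation is zero or the correlation is ±1. The helper names and the <code>EPS</code> constant below are illustrative assumptions, with the bound chosen to mirror the <code>1e-5</code> used in the code above.</p>

```python
import math

EPS = 1e-5  # assumed to mirror the clamp bound used in the model code


def log_det_cov(sigma_x, sigma_y, rho_xy):
    """Log-determinant of [[sx^2, r*sx*sy], [r*sx*sy, sy^2]]."""
    det = (sigma_x * sigma_y) ** 2 * (1.0 - rho_xy ** 2)
    return math.log(det)  # raises ValueError when det is 0


def clamped_log_det(sigma_x, sigma_y, rho_xy):
    """Same quantity, after clamping the parameters away from the degenerate cases."""
    sigma_x = max(sigma_x, EPS)
    sigma_y = max(sigma_y, EPS)
    rho_xy = max(-1.0 + EPS, min(rho_xy, 1.0 - EPS))
    return log_det_cov(sigma_x, sigma_y, rho_xy)


if __name__ == "__main__":
    try:
        log_det_cov(0.0, 1.0, 1.0)
    except ValueError as err:
        print("unclamped parameters fail:", err)
    print("clamped log-determinant:", clamped_log_det(0.0, 1.0, 1.0))
```

<p>In plain Python the degenerate case surfaces as a <code>ValueError</code>; with tensors the same situation may instead produce non-finite values or an error when the distribution is constructed.</p>
        </div>
    </div>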
    <div class='section' id='section-32'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-32'>#</a>
            </div>
            <p>Get the means</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">171</span>        <span class="n">mean</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">mu_x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mu_y</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-33'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-33'>#</a>
            </div>
            <p>Get the covariance matrix</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">173</span>        <span class="n">cov</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span>
<span class="lineno">174</span>            <span class="n">sigma_x</span> <span class="o">*</span> <span class="n">sigma_x</span><span class="p">,</span> <span class="n">rho_xy</span> <span class="o">*</span> <span class="n">sigma_x</span> <span class="o">*</span> <span class="n">sigma_y</span><span class="p">,</span>
<span class="lineno">175</span>            <span class="n">rho_xy</span> <span class="o">*</span> <span class="n">sigma_x</span> <span class="o">*</span> <span class="n">sigma_y</span><span class="p">,</span> <span class="n">sigma_y</span> <span class="o">*</span> <span class="n">sigma_y</span>
<span class="lineno">176</span>        <span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">177</span>        <span class="n">cov</span> <span class="o">=</span> <span class="n">cov</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">sigma_y</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-34'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-34'>#</a>
            </div>
            <p>Create a bivariate normal distribution.</p>
<p>📝 It would be more efficient to pass <code  class="highlight"><span></span><span class="n">scale_tril</span></code>
 as the lower-triangular matrix <code  class="highlight"><span></span><span class="p">[[</span><span class="n">a</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">]]</span></code>, where <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">a</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbo" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">b</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.716668em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqbl" style=""><span class="mord" style=""><span class="mord mathnormal" style="">ρ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" 
style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">c</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.84em;vertical-align:-0.6276250000000001em;"></span><span class="mord coloredeq eqbp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span 
class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.212375em;"><span class="svg-align" style="top:-3.8em;"><span class="pstrut" style="height:3.8em;"></span><span class="mord" style="padding-left:1em;"><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord mathnormal">ρ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7401079999999999em;"><span style="top:-2.4530000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span></span></span></span><span style="top:-2.989em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.383108em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.1723749999999997em;"><span class="pstrut" style="height:3.8em;"></span><span class="hide-tail" style="min-width:1.02em;height:1.8800000000000001em;"><svg height="1.8800000000000001em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1944" 
width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M983 90
l0 -0
c4,-6.7,10,-10,18,-10 H400000v40
H1013.1s-83.4,268,-264.1,840c-180.7,572,-277,876.3,-289,913c-4.7,4.7,-12.7,7,-24,7
s-12,0,-12,0c-1.3,-3.3,-3.7,-11.7,-7,-25c-35.3,-125.3,-106.7,-373.3,-214,-744
c-10,12,-21,25,-33,39s-32,39,-32,39c-6,-5.3,-15,-14,-27,-26s25,-30,25,-30
c26.7,-32.7,52,-63,76,-91s52,-60,52,-60s208,722,208,722
c56,-175.3,126.3,-397.3,211,-666c84.7,-268.7,153.8,-488.2,207.5,-658.5
c53.7,-170.3,84.5,-266.8,92.5,-289.5z
M1001 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.6276250000000001em;"><span></span></span></span></span></span></span></span></span></span></span>
However, for clarity we use the covariance matrix. <a href="https://www2.stat.duke.edu/courses/Spring12/sta104.1/Lectures/Lec22.pdf">This is a good resource if you want to read more about bivariate distributions, their covariance matrices, and probability density functions</a>.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">188</span>        <span class="n">multi_dist</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">MultivariateNormal</span><span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="n">covariance_matrix</span><span class="o">=</span><span class="n">cov</span><span class="p">)</span></pre></div>
        </div>
    </div>
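    <div class='section'>
        <div class='docs'>
<p>As a sanity check on the note above, the following plain-Python sketch builds the lower-triangular factor <code>[[a, 0], [b, c]]</code> and confirms that multiplying it by its own transpose recovers the covariance matrix. The function names are hypothetical helpers introduced only for this illustration; they are not part of the implementation.</p>

```python
import math


def cholesky_2x2(sigma_x, sigma_y, rho_xy):
    """Lower-triangular factor [[a, 0], [b, c]] of the bivariate covariance."""
    a = sigma_x
    b = rho_xy * sigma_y
    c = sigma_y * math.sqrt(1.0 - rho_xy ** 2)
    return [[a, 0.0], [b, c]]


def covariance_2x2(sigma_x, sigma_y, rho_xy):
    """Covariance matrix laid out the same way as in the model code."""
    off_diag = rho_xy * sigma_x * sigma_y
    return [[sigma_x * sigma_x, off_diag], [off_diag, sigma_y * sigma_y]]


def times_own_transpose(tril):
    """Compute L @ L^T for a 2x2 matrix given as nested lists."""
    return [[sum(tril[i][k] * tril[j][k] for k in range(2)) for j in range(2)]
            for i in range(2)]


if __name__ == "__main__":
    factor = cholesky_2x2(1.5, 0.7, -0.3)
    print("rebuilt:   ", times_own_transpose(factor))
    print("covariance:", covariance_2x2(1.5, 0.7, -0.3))
```

<p>With tensors, a factor of this form could be passed through the <code>scale_tril</code> argument of <code>torch.distributions.MultivariateNormal</code> instead of <code>covariance_matrix</code>, which spares the distribution from computing the decomposition itself.</p>
        </div>
    </div>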
    <div class='section' id='section-35'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-35'>#</a>
            </div>
            <p>Create a categorical distribution from the logits of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">191</span>        <span class="n">cat_dist</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span><span class="p">)</span></pre></div>
        </div>
    </div>
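    <div class='section'>
        <div class='docs'>
<p>When a categorical distribution is built from logits, the probabilities it represents are the softmax of those logits. The sketch below shows that normalization in plain Python; <code>softmax</code> is an illustrative helper written for this note, not a function from the implementation.</p>

```python
import math


def softmax(logits):
    """Turn unnormalized log-weights into mixture probabilities."""
    shift = max(logits)  # subtracting the maximum keeps exp() from overflowing
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    mixture_weights = softmax([2.0, 0.5, -1.0])
    print(mixture_weights, sum(mixture_weights))
```

<p>Because only differences between logits matter, adding the same constant to every logit leaves the mixture weights unchanged.</p>
        </div>
    </div>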
    <div class='section' id='section-36'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-36'>#</a>
            </div>
            <p></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">194</span>        <span class="k">return</span> <span class="n">cat_dist</span><span class="p">,</span> <span class="n">multi_dist</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-37'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-37'>#</a>
            </div>
            <h2>Encoder module</h2>
<p>This consists of a bidirectional LSTM.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">197</span><span class="k">class</span> <span class="nc">EncoderRNN</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-38'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-38'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">204</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_z</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">enc_hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">205</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-39'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-39'>#</a>
            </div>
            <p>Create a bidirectional LSTM that takes as input a sequence of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqt" style=""><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord coloredeq eqba" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal coloredeq eqba" style="margin-right:0.03588em">y</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqci" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqbg" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span 
class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqbg" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqcf" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">208</span>        <span class="bp">self</span><span class="o">.</span><span class="n">lstm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LSTM</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">enc_hidden_size</span><span class="p">,</span> <span class="n">bidirectional</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
        </div>
    </div>
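    <div class='section'>
        <div class='docs'>
<p>The input size of 5 corresponds to one offset pair plus a three-way pen state per step. The sketch below builds such a five-element step in plain Python. It assumes the pen-state convention described in the Sketch RNN paper (p1: pen touching the paper, p2: pen lifted after this point, p3: end of the sketch); the integer codes 0, 1 and 2 and the helper name <code>to_stroke5</code> are hypothetical choices made for this illustration.</p>

```python
def to_stroke5(dx, dy, pen_state):
    """Build one input step (dx, dy, p1, p2, p3).

    pen_state is an assumed integer code:
    0 = pen down, 1 = pen lifted after this point, 2 = end of sketch.
    """
    one_hot = [0.0, 0.0, 0.0]
    one_hot[pen_state] = 1.0
    return [float(dx), float(dy)] + one_hot


if __name__ == "__main__":
    sketch = [to_stroke5(3, -1, 0), to_stroke5(0, 4, 1), to_stroke5(0, 0, 2)]
    for step in sketch:
        print(step)
```

<p>Each step has exactly five features, which is why the LSTM above is constructed with an input size of 5.</p>
        </div>
    </div>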
    <div class='section' id='section-40'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-40'>#</a>
            </div>
            <p>Linear head to get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">μ</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">210</span>        <span class="bp">self</span><span class="o">.</span><span class="n">mu_head</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">enc_hidden_size</span><span class="p">,</span> <span class="n">d_z</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-41'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-41'>#</a>
            </div>
            <p>Linear head to get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqbi" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord" style="">^</span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">212</span>        <span class="bp">self</span><span class="o">.</span><span class="n">sigma_head</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">enc_hidden_size</span><span class="p">,</span> <span class="n">d_z</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-42'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-42'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">214</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">state</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-43'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-43'>#</a>
            </div>
            <p>The hidden state of the bidirectional LSTM is the concatenation of the output of the last token in the forward direction and the output of the first token in the reverse direction, which is what we want.<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">h</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.10680899999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mrel mtight">→</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">e</span><span class="mord mathnormal">n</span><span class="mord mathnormal">co</span><span class="mord mathnormal">d</span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.10680899999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mrel mtight">→</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" 
style="margin-right:0.05764em;">S</span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal">h</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.10680899999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mrel mtight">←</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord mathnormal">e</span><span class="mord mathnormal">n</span><span class="mord mathnormal">co</span><span class="mord mathnormal">d</span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.10680899999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mrel mtight">←</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="base"><span 
class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05764em;">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">re</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">v</span><span class="mord mathnormal mtight">erse</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="">h</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mopen" style="">[</span><span class="mord" style=""><span class="mord mathnormal" style="">h</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.10680899999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mrel mtight" style="">→</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">;</span><span class="mspace" 
style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">h</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.10680899999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mrel mtight" style="">←</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">]</span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">221</span>        <span class="n">_</span><span class="p">,</span> <span class="p">(</span><span class="n">hidden</span><span class="p">,</span> <span class="n">cell</span><span class="p">)</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lstm</span><span class="p">(</span><span class="n">inputs</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">state</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-44'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-44'>#</a>
            </div>
            <p>The state has shape <code  class="highlight"><span></span><span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">]</span></code>,
 where the first dimension is the direction. We rearrange it to get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="">h</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mopen" style="">[</span><span class="mord" style=""><span class="mord mathnormal" style="">h</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.10680899999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mrel mtight" style="">→</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">;</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">h</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.10680899999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mrel mtight" style="">←</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">]</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">225</span>        <span class="n">hidden</span> <span class="o">=</span> <span class="n">einops</span><span class="o">.</span><span class="n">rearrange</span><span class="p">(</span><span class="n">hidden</span><span class="p">,</span> <span class="s1">&#39;fb b h -&gt; b (fb h)&#39;</span><span class="p">)</span></pre></div>
        </div>
    </div>
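    <p>To make the index order of the rearrangement concrete, here is a minimal sketch in plain Python with nested lists standing in for tensors. The helper name <code>concat_directions</code> is hypothetical and not part of the implementation. It assumes that in the pattern <code>fb b h -&gt; b (fb h)</code> the direction axis is the outer index of the merged dimension, so the forward state comes first and the backward state second.</p>

```python
def concat_directions(hidden):
    """Merge a [2][batch_size][hidden_size] nested list into [batch_size][2 * hidden_size]."""
    forward, backward = hidden
    # For every sample in the batch, place the forward state before the backward state
    return [f + b for f, b in zip(forward, backward)]


hidden = [
    [[1.0, 2.0], [3.0, 4.0]],      # forward direction: batch of 2, hidden size 2
    [[-1.0, -2.0], [-3.0, -4.0]],  # backward direction
]
combined = concat_directions(hidden)
```

    <p>Each row of <code>combined</code> now holds one sample's forward and backward states side by side, which is the layout the heads for the mean and the variance receive.</p>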
    <div class='section' id='section-45'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-45'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">μ</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">228</span>        <span class="n">mu</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mu_head</span><span class="p">(</span><span class="n">hidden</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-46'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-46'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqbi" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord" style="">^</span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">230</span>        <span class="n">sigma_hat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_head</span><span class="p">(</span><span class="n">hidden</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-47'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-47'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.2251079999999999em;vertical-align:-0.345em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8801079999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbi" style=""><span class="mord accent mtight" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="mord mathnormal mtight" style="margin-right:0.03588em">σ</span></span><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord mtight" style="">^</span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">232</span>        <span class="n">sigma</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">sigma_hat</span> <span class="o">/</span> <span class="mf">2.</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-48'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-48'>#</a>
            </div>
            <p>Sample <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.7777700000000001em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">μ</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.44445em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="mclose">)</span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">235</span>        <span class="n">z</span> <span class="o">=</span> <span class="n">mu</span> <span class="o">+</span> <span class="n">sigma</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">mu</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">mu</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">mu</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">mu</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span></pre></div>
        </div>
    </div>
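    <p>The two steps above (turning the predicted value into a standard deviation, then sampling) can be sketched without PyTorch as follows. This is only an illustration: <code>sample_latent</code> is a hypothetical helper, the vectors are plain lists, and the comment that the head's output acts as a log-variance is an inference from the division by two in the exponent.</p>

```python
import math
import random


def sample_latent(mu, sigma_hat, rng):
    """Draw z = mu + sigma * eps elementwise, with eps ~ N(0, I)."""
    # The head's output behaves like a log-variance, so sigma = exp(sigma_hat / 2)
    sigma = [math.exp(s / 2.0) for s in sigma_hat]
    # One independent standard normal draw per latent dimension
    eps = [rng.gauss(0.0, 1.0) for _ in mu]
    return [m + s * e for m, s, e in zip(mu, sigma, eps)]


rng = random.Random(0)  # fixed seed so the sketch is repeatable
mu = [0.5, -1.0, 2.0]
sigma_hat = [0.0, -2.0, -60.0]
z = sample_latent(mu, sigma_hat, rng)
```

    <p>With a very negative value in <code>sigma_hat</code> the standard deviation is almost zero, so the corresponding entry of <code>z</code> stays very close to the mean. Writing the sample as the mean plus scaled noise keeps the randomness in the noise term.</p>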
    <div class='section' id='section-49'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-49'>#</a>
            </div>
            <p></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">238</span>        <span class="k">return</span> <span class="n">z</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">sigma_hat</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-50'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-50'>#</a>
            </div>
            <h2>Decoder module</h2>
<p>This consists of an LSTM</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">241</span><span class="k">class</span> <span class="nc">DecoderRNN</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-51'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-51'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">248</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_z</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dec_hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_distributions</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">249</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-52'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-52'>#</a>
            </div>
            <p>The LSTM takes as input <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqn" style=""><span class="mopen" style="">[</span><span class="mord" style=""><span class="mopen coloredeq eqt" style="">(</span><span class="mord coloredeq eqt" style=""><span class="mord coloredeq eqba" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal coloredeq eqba" style="margin-right:0.03588em">y</span></span><span class="mpunct coloredeq eqt" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqt" style=""><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqci" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqbg" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqce" style=""><span class="mord 
mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqbg" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqcf" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mclose coloredeq eqt" style="">)</span></span><span class="mpunct" style="">;</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcl" style="margin-right:0.04398em">z</span></span><span class="mclose" style="">]</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">251</span>        <span class="bp">self</span><span class="o">.</span><span class="n">lstm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LSTM</span><span class="p">(</span><span class="n">d_z</span> <span class="o">+</span> <span class="mi">5</span><span class="p">,</span> <span class="n">dec_hidden_size</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-53'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-53'>#</a>
            </div>
            <p>The initial state of the LSTM is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqw" style=""><span class="mopen" style="">[</span><span class="mord" style=""><span class="mord mathnormal" style="">h</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqch" style="">0</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">;</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">c</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqch" style="">0</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">]</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop" style=""><span style="">t</span><span style="">a</span><span 
style="">n</span><span style="">h</span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcl" style="margin-right:0.04398em">z</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcl" style="margin-right:0.04398em">z</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcl" style="margin-right:0.04398em">z</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span>. <code  class="highlight"><span></span><span class="n">init_state</span></code>
is the linear transformation for this.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">255</span>        <span class="bp">self</span><span class="o">.</span><span class="n">init_state</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_z</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">dec_hidden_size</span><span class="p">)</span></pre></div>
        </div>
    </div>
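    <p>As a rough illustration of how one linear layer of width <code>2 * dec_hidden_size</code> can provide both parts of the initial state, the following plain Python sketch applies the linear map, takes the hyperbolic tangent and splits the result in half. The function name <code>initial_state</code>, the example weights and the choice that the first half is the hidden state and the second half is the cell state are all assumptions made for the example.</p>

```python
import math


def initial_state(z, weight, bias, dec_hidden_size):
    """Compute tanh(W z + b) and split it into an (h0, c0) pair."""
    # Linear map from d_z to 2 * dec_hidden_size: one row of `weight` per output unit
    pre = [sum(w * x for w, x in zip(row, z)) + b for row, b in zip(weight, bias)]
    state = [math.tanh(v) for v in pre]
    # Assumed ordering: first half is h0, second half is c0
    return state[:dec_hidden_size], state[dec_hidden_size:]


dec_hidden_size = 2
z = [1.0, -1.0]  # latent vector with d_z = 2
weight = [[0.5, 0.0], [0.0, 0.5], [1.0, 1.0], [2.0, 0.0]]  # shape [2 * dec_hidden_size, d_z]
bias = [0.0, 0.0, 0.0, 0.0]
h0, c0 = initial_state(z, weight, bias, dec_hidden_size)
```

    <p>Because of the hyperbolic tangent, every entry of both halves lies strictly between -1 and 1.</p>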
    <div class='section' id='section-54'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-54'>#</a>
            </div>
            <p>This layer produces outputs for each of the <code  class="highlight"><span></span><span class="n">n_distributions</span></code>
distributions. Each distribution needs six parameters:</p> <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.232878em;vertical-align:-0.286108em;"></span><span class="mopen">(</span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139199999999997em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="msupsub"><span 
class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2501em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139199999999997em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2501em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2501em;"><span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" 
style="margin-right:0.16666666666666666em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" 
style="height:3em;"></span><span class="mord"><span class="mord mathnormal">ρ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">260</span>        <span class="bp">self</span><span class="o">.</span><span class="n">mixtures</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dec_hidden_size</span><span class="p">,</span> <span class="mi">6</span> <span class="o">*</span> <span class="n">n_distributions</span><span class="p">)</span></pre></div>
        </div>
    </div>
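    <p>To illustrate why the layer has <code>6 * n_distributions</code> outputs, here is a small plain Python sketch that regroups a flat output vector into one set of six raw parameters per mixture component. The helper <code>split_mixture_params</code>, the key names and the contiguous layout are hypothetical. The implementation may order the values differently.</p>

```python
def split_mixture_params(output, n_distributions):
    """Regroup a flat vector of 6 * n_distributions values into one dict per component."""
    if len(output) != 6 * n_distributions:
        raise ValueError("expected 6 values per mixture component")
    names = ("pi_hat", "mu_x", "mu_y", "sigma_x_hat", "sigma_y_hat", "rho_xy_hat")
    # Assumed layout: the six parameters of component i are stored contiguously
    return [dict(zip(names, output[6 * i: 6 * i + 6])) for i in range(n_distributions)]


n_distributions = 2
output = [float(v) for v in range(12)]  # stand-in for one time step of the layer's output
components = split_mixture_params(output, n_distributions)
```

    <p>Each entry of <code>components</code> holds the raw mixture weight, the two means, the two raw standard deviations and the raw correlation for one bivariate distribution.</p>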
    <div class='section' id='section-55'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-55'>#</a>
            </div>
            <p>This head is for the logits <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqck" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqci" style=""><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.19444em;"><span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.19444em;"><span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqck" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.19444em;"><span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">263</span>        <span class="bp">self</span><span class="o">.</span><span class="n">q_head</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dec_hidden_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-56'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-56'>#</a>
            </div>
            <p>This is to calculate <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mopen">(</span><span class="mord"><span class="mord coloredeq eqck" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>, where <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord coloredeq eqck" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span 
class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop"><span class="mord mathrm">softmax</span></span><span class="mopen">(</span><span class="mord coloredeq eqbr" style=""><span class="mord accent" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" style="margin-right:0.03588em">q</span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.19444em;"><span></span></span></span></span></span></span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.706826em;vertical-align:-1.279826em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:1.427em;"><span style="top:-2.1559920000000004em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.954008em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span><span class="mrel mtight">=</span><span class="mord mtight coloredeq eqci" style=""><span class="mord mtight" style="">1</span></span></span></span></span><span style="top:-3.2029em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqck" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" 
style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span><span class="mclose">)</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop">exp</span><span class="mopen">(</span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqck" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span 
class="vlist" style="height:0.19444em;"><span></span></span></span></span></span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.279826em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">266</span>        <span class="bp">self</span><span class="o">.</span><span class="n">q_log_softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LogSoftmax</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
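    <div class='section'>
        <div class='docs'>
            <p>As an illustration of the log-softmax above, here is a minimal sketch in plain Python (standard library only, not the PyTorch module itself). The helper name <code>log_softmax</code> and the sample logits are assumptions made for this example. Subtracting the maximum logit before exponentiating is a common stability trick and does not change the result.</p>
        </div>
        <div class='code'>

```python
import math


def log_softmax(logits):
    """Return log(softmax(logits)) for a flat list of floats."""
    # Shift by the maximum so exp() cannot overflow; the shift cancels out.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


# Hypothetical logits for the three pen states (q_1, q_2, q_3).
q_hat = [2.0, 0.5, -1.0]
log_q = log_softmax(q_hat)

# Exponentiating the log-probabilities recovers q_k, which sum to 1.
q = [math.exp(v) for v in log_q]
```

        </div>
    </div>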
    <div class='section' id='section-57'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-57'>#</a>
            </div>
            <p>These parameters are stored for future reference</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">269</span>        <span class="bp">self</span><span class="o">.</span><span class="n">n_distributions</span> <span class="o">=</span> <span class="n">n_distributions</span>
<span class="lineno">270</span>        <span class="bp">self</span><span class="o">.</span><span class="n">dec_hidden_size</span> <span class="o">=</span> <span class="n">dec_hidden_size</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-58'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-58'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">272</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">z</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-59'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-59'>#</a>
            </div>
            <p>Calculate the initial state</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">274</span>        <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-60'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-60'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqw" style=""><span class="mopen" style="">[</span><span class="mord" style=""><span class="mord mathnormal" style="">h</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqch" style="">0</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">;</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">c</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqch" style="">0</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">]</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop" style=""><span style="">t</span><span style="">a</span><span 
style="">n</span><span style="">h</span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcl" style="margin-right:0.04398em">z</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcl" style="margin-right:0.04398em">z</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcl" style="margin-right:0.04398em">z</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">276</span>            <span class="n">h</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">init_state</span><span class="p">(</span><span class="n">z</span><span class="p">)),</span> <span class="bp">self</span><span class="o">.</span><span class="n">dec_hidden_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
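    <div class='section'>
        <div class='docs'>
            <p>The split above can be sketched for a single sample in plain Python. This assumes that the projection of <code>z</code> has <code>2 * dec_hidden_size</code> values, which is what splitting it into two chunks of <code>dec_hidden_size</code> implies. The helper name <code>split_initial_state</code> and the sample numbers are hypothetical.</p>
        </div>
        <div class='code'>

```python
import math


def split_initial_state(projected, dec_hidden_size):
    """Apply tanh to one projected row, then split it into (h_0, c_0)."""
    if len(projected) != 2 * dec_hidden_size:
        raise ValueError("expected 2 * dec_hidden_size values")
    activated = [math.tanh(v) for v in projected]
    # The first half becomes h_0 and the second half becomes c_0.
    return activated[:dec_hidden_size], activated[dec_hidden_size:]


# Hypothetical projection W_z z + b_z for one sample, dec_hidden_size = 2.
h0, c0 = split_initial_state([0.0, 1.0, -1.0, 3.0], dec_hidden_size=2)
```

        </div>
    </div>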
    <div class='section' id='section-61'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-61'>#</a>
            </div>
            <p><code  class="highlight"><span></span><span class="n">h</span></code>
and <code  class="highlight"><span></span><span class="n">c</span></code>
have shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">lstm_size</span><span class="p">]</span></code>.
We reshape them to <code  class="highlight"><span></span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">lstm_size</span><span class="p">]</span></code>
because that is the shape the LSTM uses for its state.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">279</span>            <span class="n">state</span> <span class="o">=</span> <span class="p">(</span><span class="n">h</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">(),</span> <span class="n">c</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">())</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-62'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-62'>#</a>
            </div>
            <p>Run the LSTM</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">282</span>        <span class="n">outputs</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lstm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-63'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-63'>#</a>
            </div>
            <p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mopen">(</span><span class="mord coloredeq eqck" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span></span><span class="mclose">)</span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">285</span>        <span class="n">q_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_log_softmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">q_head</span><span class="p">(</span><span class="n">outputs</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-64'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-64'>#</a>
            </div>
            <p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.232878em;vertical-align:-0.286108em;"></span><span class="mopen">(</span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mpunct mtight">,</span><span class="mord mathnormal 
mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span 
class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span 
class="mord"><span class="mord mathnormal">ρ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>. <code  class="highlight"><span></span><span class="n">torch</span><span class="o">.</span><span class="n">split</span></code>
splits the output into 6 tensors of size <code  class="highlight"><span></span><span class="bp">self</span><span class="o">.</span><span class="n">n_distributions</span></code>
across dimension <code  class="highlight"><span></span><span class="mi">2</span></code>.</p>


        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">291</span>        <span class="n">pi_logits</span><span class="p">,</span> <span class="n">mu_x</span><span class="p">,</span> <span class="n">mu_y</span><span class="p">,</span> <span class="n">sigma_x</span><span class="p">,</span> <span class="n">sigma_y</span><span class="p">,</span> <span class="n">rho_xy</span> <span class="o">=</span> \
<span class="lineno">292</span>            <span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mixtures</span><span class="p">(</span><span class="n">outputs</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_distributions</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
        </div>
    </div>
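    <div class='section'>
        <div class='docs'>
            <p>To make the six-way split concrete, here is a rough plain-Python sketch for a single output row. It assumes the mixture head produces <code>6 * n_distributions</code> values per step, which is what unpacking into six equal chunks implies. The helper name <code>split_mixture_params</code> and the sample row are hypothetical.</p>
        </div>
        <div class='code'>

```python
def split_mixture_params(row, n_distributions):
    """Split one output row into six consecutive chunks of n_distributions."""
    if len(row) != 6 * n_distributions:
        raise ValueError("expected 6 * n_distributions values")
    # Consecutive, equally sized chunks, in the same order as the unpacking.
    chunks = [row[i:i + n_distributions]
              for i in range(0, len(row), n_distributions)]
    pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = chunks
    return pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy


# Hypothetical row with n_distributions = 2, so 12 values in total.
row = list(range(12))
pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = split_mixture_params(row, 2)
```

        </div>
    </div>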
    <div class='section' id='section-65'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-65'>#</a>
            </div>
            <p>Create a bi-variate Gaussian mixture from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathcal" style="margin-right:0.14736em">N</span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="">ρ</span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span> <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.716668em;vertical-align:-0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mop">exp</span><span class="mopen">(</span><span 
class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal">ρ</span><span class="msupsub"><span class="vlist-t 
vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mop">tanh</span><span class="mopen">(</span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">ρ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span style="top:-3em;"><span 
class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></span> <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.19677em;vertical-align:-0.25em;"></span><span class="mop"><span class="mord mathrm">softmax</span></span><span class="mopen">(</span><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span 
class="mord">^</span></span></span></span></span></span></span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.903596em;vertical-align:-1.279826em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.62377em;"><span style="top:-2.1559920000000004em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.954008em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span><span class="mrel mtight">=</span><span class="mord mtight coloredeq eqci" style=""><span class="mord mtight" style="">1</span></span></span></span></span><span style="top:-3.2029em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 
mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span><span class="mclose">)</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop">exp</span><span class="mopen">(</span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.279826em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span></p>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span></span></span></span></span> gives the categorical probabilities of choosing a distribution out of the mixture <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathcal" style="margin-right:0.14736em">N</span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="">ρ</span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">305</span>        <span class="n">dist</span> <span class="o">=</span> <span class="n">BivariateGaussianMixture</span><span class="p">(</span><span class="n">pi_logits</span><span class="p">,</span> <span class="n">mu_x</span><span class="p">,</span> <span class="n">mu_y</span><span class="p">,</span>
<span class="lineno">306</span>                                        <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">sigma_x</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">sigma_y</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">rho_xy</span><span class="p">))</span></pre></div>
        </div>
    </div>
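    <div class='section'>
        <div class='docs'>
            <p>The parameter transformations in the equations above can be sketched in plain Python. This is only an illustration under stated assumptions: it uses the standard library instead of the PyTorch tensors used in this implementation, it applies the softmax explicitly whereas the code above passes the raw logits to <code>BivariateGaussianMixture</code>, and the helper names <code>softmax</code> and <code>to_mixture_params</code> are hypothetical.</p>

```python
import math


def softmax(logits):
    """Turn raw logits into probabilities that sum to one."""
    # Subtract the maximum first so exp() cannot overflow.
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_mixture_params(pi_logits, sigma_x_hat, sigma_y_hat, rho_xy_hat):
    """Map unconstrained network outputs to valid mixture parameters."""
    pi = softmax(pi_logits)                       # categorical probabilities
    sigma_x = [math.exp(s) for s in sigma_x_hat]  # exp keeps deviations positive
    sigma_y = [math.exp(s) for s in sigma_y_hat]
    rho_xy = [math.tanh(r) for r in rho_xy_hat]   # tanh keeps correlations in (-1, 1)
    return pi, sigma_x, sigma_y, rho_xy


# Illustrative raw outputs for a mixture of three components.
pi, sigma_x, sigma_y, rho_xy = to_mixture_params(
    [0.0, 1.0, -1.0], [0.0, 0.5, -0.5], [0.1, 0.2, 0.3], [0.0, 2.0, -2.0])
```

            <p>Each transformation exists to enforce a constraint: the mixture weights must sum to one, standard deviations must be positive, and correlations must lie strictly between -1 and 1.</p>
        </div>
    </div>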
    <div class='section' id='section-66'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-66'>#</a>
            </div>
            <p></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">309</span>        <span class="k">return</span> <span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">,</span> <span class="n">state</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-67'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-67'>#</a>
            </div>
            <h2>Reconstruction Loss</h2>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">312</span><span class="k">class</span> <span class="nc">ReconstructionLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-68'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-68'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">317</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">318</span>                 <span class="n">dist</span><span class="p">:</span> <span class="s1">&#39;BivariateGaussianMixture&#39;</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-69'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-69'>#</a>
            </div>
            <p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathcal" style="margin-right:0.14736em">N</span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="">ρ</span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">320</span>        <span class="n">pi</span><span class="p">,</span> <span class="n">mix</span> <span class="o">=</span> <span class="n">dist</span><span class="o">.</span><span class="n">get_distribution</span><span class="p">()</span></pre></div>
        </div>
    </div>
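    <div class='section'>
        <div class='docs'>
            <p>To make the roles of the mixture weights and the component distributions concrete, here is a minimal sketch in plain Python of the density a bivariate Gaussian mixture assigns to an offset. It is not the code of this implementation, which obtains both objects from <code>dist.get_distribution()</code>; the helper names <code>bivariate_normal_pdf</code> and <code>mixture_pdf</code> and the example parameter values are assumptions made for illustration.</p>

```python
import math


def bivariate_normal_pdf(dx, dy, mu_x, mu_y, sigma_x, sigma_y, rho_xy):
    """Density of a bivariate normal with correlation rho_xy at (dx, dy)."""
    zx = (dx - mu_x) / sigma_x
    zy = (dy - mu_y) / sigma_y
    one_minus_rho2 = 1.0 - rho_xy ** 2
    z = zx ** 2 - 2.0 * rho_xy * zx * zy + zy ** 2
    norm = 2.0 * math.pi * sigma_x * sigma_y * math.sqrt(one_minus_rho2)
    return math.exp(-z / (2.0 * one_minus_rho2)) / norm


def mixture_pdf(dx, dy, pi, components):
    """Sum of the component densities, weighted by the mixture weights pi."""
    return sum(p * bivariate_normal_pdf(dx, dy, *c) for p, c in zip(pi, components))


# Two illustrative components: (mu_x, mu_y, sigma_x, sigma_y, rho_xy)
components = [(0.0, 0.0, 1.0, 1.0, 0.0), (1.0, -1.0, 0.5, 2.0, 0.3)]
pi = [0.7, 0.3]
density = mixture_pdf(0.0, 0.0, pi, components)
```

            <p>A loss over pen offsets would typically take the negative logarithm of such a density; in practice this is done in log space for numerical stability.</p>
        </div>
    </div>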
    <div class='section' id='section-70'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-70'>#</a>
            </div>
            <p><code  class="highlight"><span></span><span class="n">target</span></code>
has shape
<code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="mi">5</span><span class="p">]</span></code>
where the last dimension holds the features <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqt" style=""><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord coloredeq eqba" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal coloredeq eqba" style="margin-right:0.03588em">y</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqci" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqbg" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqbg" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqcf" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span>. We want to take <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbc" style=""><span class="mord" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style="">Δ</span></span></span></span></span></span>y and get its probability under each of the distributions in the mixture</p> <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span 
class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathcal" style="margin-right:0.14736em">N</span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="">ρ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span><span class="mord 
mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span>
<p><code  class="highlight"><span></span><span class="n">xy</span></code>
will have shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">n_distributions</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span></code>
</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">327</span>        <span class="n">xy</span> <span class="o">=</span> <span class="n">target</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">dist</span><span class="o">.</span><span class="n">n_distributions</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
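    <div class='section'>
        <div class='docs'>
            <p>As a small illustration of the broadcasting step above, the sketch below repeats the same indexing on random data. The sizes and the random <code>target</code> tensor are assumptions made only for this example; they are not values from the implementation.</p>
        </div>
        <div class='code'>

```python
import torch

# Hypothetical sizes, chosen only for illustration
seq_len, batch_size, n_distributions = 4, 3, 5

# Fake targets with features (dx, dy, p1, p2, p3) in the last dimension
target = torch.randn(seq_len, batch_size, 5)

# Keep (dx, dy), add a mixture axis, and broadcast it over all components.
# `expand` returns a view, so no data is copied.
xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, n_distributions, -1)

print(xy.shape)

# Every mixture component receives the same (dx, dy) pair
same_for_all_components = bool((xy == xy[:, :, :1]).all())
print(same_for_all_components)
```

        </div>
    </div>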
    <div class='section' id='section-71'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-71'>#</a>
            </div>
            <p>Calculate the probabilities <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqy" style=""><span class="mord mathnormal" style="">p</span><span class="mord" style=""><span class="mopen coloredeq eqz" style="">(</span><span class="mord coloredeq eqz" style=""><span class="mord coloredeq eqba" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal coloredeq eqba" style="margin-right:0.03588em">y</span></span><span class="mclose coloredeq eqz" style="">)</span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:3.2421130000000007em;vertical-align:-1.4137769999999998em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8283360000000006em;"><span style="top:-1.872331em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span><span class="mrel mtight">=</span><span class="mord mtight coloredeq eqci" style=""><span class="mord mtight" style="">1</span></span></span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span><span 
style="top:-4.3000050000000005em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqcj" style=""><span class="mord mathnormal mtight" style="margin-right:0.10903em">M</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.4137769999999998em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mord"><span class="delimsizing size1">(</span></span><span class="mord coloredeq eqba" style=""><span class="mord" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal" style="margin-right:0.03588em">y</span></span><span class="mord">∣</span><span class="mord"><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord 
mtight"><span class="mord mathnormal mtight">x</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal">ρ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size1">)</span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">333</span>        <span class="n">probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">pi</span><span class="o">.</span><span class="n">probs</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">mix</span><span class="o">.</span><span class="n">log_prob</span><span class="p">(</span><span class="n">xy</span><span class="p">)),</span> <span class="mi">2</span><span class="p">)</span></pre></div>
        </div>
    </div>
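    <div class='section'>
        <div class='docs'>
            <p>The mixture sum above can be tried in isolation with <code>torch.distributions</code>. This is a minimal sketch under stated assumptions: the mixture parameters are random stand-ins, and the way <code>pi</code> and <code>mix</code> are built here is hypothetical rather than the implementation's own code. The <code>logsumexp</code> variant at the end is an equivalent log-space form, shown only as an alternative that avoids exponentiating small log-probabilities.</p>
        </div>
        <div class='code'>

```python
import torch
from torch.distributions import Categorical, MultivariateNormal

torch.manual_seed(0)

# Hypothetical sizes and random parameters, for illustration only
seq_len, batch_size, n_distributions = 4, 3, 5
shape = (seq_len, batch_size, n_distributions)

# Mixture weights Pi_j
pi = Categorical(logits=torch.randn(*shape))

# Means, standard deviations and correlations of the bivariate Gaussians
mu = torch.randn(*shape, 2)
sx = torch.rand(*shape) + 0.5
sy = torch.rand(*shape) + 0.5
rho = 0.9 * torch.tanh(torch.randn(*shape))  # kept strictly inside (-1, 1)

# Covariance matrices [[sx^2, rho*sx*sy], [rho*sx*sy, sy^2]]
cov_xy = rho * sx * sy
cov = torch.stack([sx ** 2, cov_xy, cov_xy, sy ** 2], dim=-1).reshape(*shape, 2, 2)
mix = MultivariateNormal(mu, covariance_matrix=cov)

# Targets (dx, dy), broadcast over the mixture components
xy = torch.randn(seq_len, batch_size, 1, 2).expand(-1, -1, n_distributions, -1)

# p(dx, dy) = sum_j Pi_j * N(dx, dy | mu_j, sigma_j, rho_j)
probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)

# Same quantity in log space
log_probs = torch.logsumexp(torch.log(pi.probs) + mix.log_prob(xy), dim=2)

print(probs.shape)
```

        </div>
    </div>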
    <div class='section' id='section-72'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-72'>#</a>
            </div>
            <p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">s</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:3.1171050000000005em;vertical-align:-1.277669em;"></span><span class="mord">−</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.3139999999999996em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqbq" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">ma</span><span class="mord mathnormal mtight" 
style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.8360000000000001em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8394360000000003em;"><span style="top:-1.872331em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight coloredeq eqci" style=""><span class="mord mtight" style="">1</span></span></span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span><span style="top:-4.311105em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqca" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16454285714285719em;"><span 
style="top:-2.357em;margin-left:-0.10903em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="">s</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.277669em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="delimsizing size1">(</span></span><span class="mord coloredeq eqy" style=""><span class="mord mathnormal" style="">p</span><span class="mord" style=""><span class="mopen coloredeq eqz" style="">(</span><span class="mord coloredeq eqz" style=""><span class="mord coloredeq eqba" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal coloredeq eqba" style="margin-right:0.03588em">y</span></span><span class="mclose coloredeq eqz" style="">)</span></span></span><span class="mord"><span class="delimsizing size1">)</span></span></span></span></span></span></span>
Although <code  class="highlight"><span></span><span class="n">probs</span></code> has <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbq" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">ma</span><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> (<code  class="highlight"><span></span><span class="n">longest_seq_len</span></code>) elements, the sum is only taken up to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqca" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">s</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> because the rest is masked out.</p>
<p>It might feel like we should take the sum and divide by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqca" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">s</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> rather than <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbq" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">ma</span><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, but that would give higher weight to individual predictions in shorter sequences. Dividing by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbq" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">ma</span><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> gives equal weight to each prediction <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqy" style=""><span class="mord mathnormal" style="">p</span><span class="mord" style=""><span class="mopen coloredeq eqz" style="">(</span><span class="mord coloredeq eqz" style=""><span class="mord coloredeq eqba" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal coloredeq eqba" style="margin-right:0.03588em">y</span></span><span class="mclose coloredeq eqz" style="">)</span></span></span></span></span></span></span>.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">342</span>        <span class="n">loss_stroke</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">mask</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mf">1e-5</span> <span class="o">+</span> <span class="n">probs</span><span class="p">))</span></pre></div>
        </div>
    </div>
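    <div class='section'>
        <div class='docs'>
            <p>To make the weighting argument concrete, the sketch below compares the two normalizations on a toy batch. The mask, the constant probabilities and the names <code>per_seq_nmax</code> and <code>per_seq_ns</code> are assumptions introduced for this example only. Note that <code>torch.mean</code> in the loss above also averages over the batch dimension.</p>
        </div>
        <div class='code'>

```python
import torch

# Hypothetical batch: longest_seq_len = 4, batch_size = 2.
# Sequence 0 has N_s = 4 valid steps, sequence 1 has N_s = 2.
mask = torch.tensor([[1., 1.],
                     [1., 1.],
                     [1., 0.],
                     [1., 0.]])

# Pretend every step was predicted with the same probability
probs = torch.full((4, 2), 0.5)

# Masked log-probabilities; padded steps contribute zero
log_p = mask * torch.log(1e-5 + probs)

# Dividing by N_max: every valid step has the same weight,
# so the shorter sequence contributes a smaller loss
per_seq_nmax = -log_p.sum(0) / mask.shape[0]

# Dividing by N_s: each sequence is averaged over its own length,
# so each step of the shorter sequence counts twice as much
per_seq_ns = -log_p.sum(0) / mask.sum(0)

# The loss used above equals the batch mean of the N_max version
loss_stroke = -torch.mean(log_p)

print(per_seq_nmax, per_seq_ns, loss_stroke)
```

        </div>
    </div>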
    <div class='section' id='section-73'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-73'>#</a>
            </div>
            <p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord"><span class="mord mathnormal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">p</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:3.1415490000000004em;vertical-align:-1.302113em;"></span><span class="mord">−</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.3139999999999996em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqbq" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">ma</span><span class="mord mathnormal mtight" 
style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.8360000000000001em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8394360000000003em;"><span style="top:-1.872331em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight coloredeq eqci" style=""><span class="mord mtight" style="">1</span></span></span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span><span style="top:-4.311105em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbq" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16454285714285719em;"><span 
style="top:-2.357em;margin-left:-0.10903em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">ma</span><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.277669em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.801113em;"><span style="top:-1.8478869999999998em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span><span class="mrel mtight">=</span><span class="mord mtight coloredeq eqci" style=""><span class="mord mtight" style="">1</span></span></span></span></span><span style="top:-3.0500049999999996em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span><span style="top:-4.300005em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">3</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.302113em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span 
class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mopen">(</span><span class="mord"><span class="mord coloredeq eqck" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">345</span>        <span class="n">loss_pen</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">target</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">*</span> <span class="n">q_logits</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-74'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-74'>#</a>
            </div>
            <p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbz" style=""><span class="mord" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.00773em">R</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">s</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" 
style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord"><span class="mord mathnormal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">p</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">348</span>        <span class="k">return</span> <span class="n">loss_stroke</span> <span class="o">+</span> <span class="n">loss_pen</span></pre></div>
        </div>
    </div>
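    <p>The pen-state term above can be tried on random data. This is a minimal sketch, not the author's training code: the tensor shapes, the helper name <code>pen_loss</code>, and the assumption that <code>q_logits</code> already holds log-probabilities (for example from <code>log_softmax</code>) are illustrative choices.</p>

```python
import torch
import torch.nn.functional as F


def pen_loss(target: torch.Tensor, q_logits: torch.Tensor) -> torch.Tensor:
    """Cross entropy between one-hot pen states and predicted log-probabilities.

    target:   [seq_len, batch, 5], columns 2..4 are the one-hot pen state (p1, p2, p3)
    q_logits: [seq_len, batch, 3], assumed to be log-probabilities
    """
    return -torch.mean(target[:, :, 2:] * q_logits)


torch.manual_seed(0)
seq_len, batch = 4, 2  # hypothetical sizes

# Random one-hot pen states and random (dx, dy) offsets
pen = F.one_hot(torch.randint(0, 3, (seq_len, batch)), 3).float()
target = torch.cat([torch.randn(seq_len, batch, 2), pen], dim=-1)

# Random log-probabilities standing in for the decoder output
q_logits = torch.log_softmax(torch.randn(seq_len, batch, 3), dim=-1)

loss_p = pen_loss(target, q_logits)

# Summing over the three pen states first and then averaging over steps
# gives the same value up to a constant factor of 3
per_step = -(target[:, :, 2:] * q_logits).sum(dim=-1).mean()
```

    <p>Because <code>torch.mean</code> averages over every element, including the three pen-state entries, the value is one third of the per-step average. The constant factor only rescales the loss.</p>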
    <div class='section' id='section-75'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-75'>#</a>
            </div>
            <h2>KL-Divergence loss</h2>
<p>This calculates the KL divergence between a given normal distribution and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span><span class="mclose">)</span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">351</span><span class="k">class</span> <span class="nc">KLDivLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-76'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-76'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">358</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sigma_hat</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mu</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-77'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-77'>#</a>
            </div>
            <p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbv" style=""><span class="mord" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span><span class="mord mathnormal mtight" style="">L</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.40003em;vertical-align:-0.95003em;"></span><span class="mord">−</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.3139999999999996em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">2</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 
mtight"><span class="mord mtight coloredeq eqcl" style=""><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.8360000000000001em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="delimsizing size3">(</span></span><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.77777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqbi" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord" style="">^</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" 
style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.0585479999999998em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">μ</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8641079999999999em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:2.40003em;vertical-align:-0.95003em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord coloredeq eqbi" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord" style="">^</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mord"><span class="delimsizing size3">)</span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">360</span>        <span class="k">return</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">sigma_hat</span> <span class="o">-</span> <span class="n">mu</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">sigma_hat</span><span class="p">))</span></pre></div>
        </div>
    </div>
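    <p>As a sanity check, the closed-form expression above can be compared with PyTorch's own KL divergence between two normal distributions. This is a sketch under one assumption that the formula suggests but this section does not state: <code>sigma_hat</code> is the log-variance, so the standard deviation is <code>exp(sigma_hat / 2)</code>. The function name <code>kl_div_loss</code> and the tensor sizes are hypothetical.</p>

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_div_loss(sigma_hat: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
    """Closed-form KL divergence to N(0, 1), averaged over all elements."""
    return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))


torch.manual_seed(0)
mu = torch.randn(8, 16)         # hypothetical [batch, d_z]
sigma_hat = torch.randn(8, 16)  # assumed to be the log-variance

# Reference value from torch.distributions
posterior = Normal(mu, torch.exp(sigma_hat / 2))
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
reference = kl_divergence(posterior, prior).mean()

loss_kl = kl_div_loss(sigma_hat, mu)
```

    <p>Under that assumption the two values should agree up to floating-point error, and the result is non-negative, as a KL divergence must be.</p>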
    <div class='section' id='section-78'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-78'>#</a>
            </div>
            <h2>Sampler</h2>
<p>This samples a sketch from the decoder and plots it.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">363</span><span class="k">class</span> <span class="nc">Sampler</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-79'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-79'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">370</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">EncoderRNN</span><span class="p">,</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">DecoderRNN</span><span class="p">):</span>
<span class="lineno">371</span>        <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">decoder</span>
<span class="lineno">372</span>        <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-80'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-80'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">374</span>    <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-81'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-81'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbq" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">ma</span><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">376</span>        <span class="n">longest_seq_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-82'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-82'>#</a>
            </div>
            <p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span></span> from the encoder</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">379</span>        <span class="n">z</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-83'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-83'>#</a>
            </div>
            <p>The start-of-sequence stroke is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqbd" style=""><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord coloredeq eqch" style="">0</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqch" style="">0</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqci" style="">1</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqch" style="">0</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqch" style="">0</span></span><span class="mclose" style="">)</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">382</span>        <span class="n">s</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span>
<span class="lineno">383</span>        <span class="n">seq</span> <span class="o">=</span> <span class="p">[</span><span class="n">s</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-84'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-84'>#</a>
            </div>
            <p>The initial decoder state is <code  class="highlight"><span></span><span class="kc">None</span></code>.
 The decoder will initialize it to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqw" style=""><span class="mopen" style="">[</span><span class="mord" style=""><span class="mord mathnormal" style="">h</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqch" style="">0</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">;</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">c</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqch" style="">0</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">]</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop" style=""><span style="">t</span><span style="">a</span><span 
style="">n</span><span style="">h</span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcl" style="margin-right:0.04398em">z</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcl" style="margin-right:0.04398em">z</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqcl" style="margin-right:0.04398em">z</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">386</span>        <span class="n">state</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
        </div>
    </div>
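    <p>The state initialization described above can be sketched in isolation. This is not the decoder defined elsewhere in this implementation: the layer name <code>init_state</code>, the sizes, and the single-layer <code>nn.LSTM</code> are assumptions made only to show the shapes involved when the state is derived from the latent vector.</p>

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only for illustration
d_z, dec_hidden_size, batch_size = 128, 512, 1

# One linear layer plays the role of W_z and b_z; it outputs h_0 and c_0 together
init_state = nn.Linear(d_z, 2 * dec_hidden_size)
lstm = nn.LSTM(5 + d_z, dec_hidden_size)

z = torch.randn(batch_size, d_z)

# [h_0; c_0] = tanh(W_z z + b_z), then split into the two halves
h, c = torch.split(torch.tanh(init_state(z)), dec_hidden_size, dim=1)

# nn.LSTM expects states shaped [num_layers, batch, hidden]
state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())

# Start-of-sequence stroke concatenated with z, shaped [seq_len=1, batch, 5 + d_z]
s = torch.tensor([0., 0., 1., 0., 0.])
inputs = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)

outputs, state = lstm(inputs, state)
```

    <p>Because of the <code>tanh</code>, every entry of the initial hidden and cell state lies between -1 and 1. The returned <code>state</code> can be fed back into the next step.</p>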
    <div class='section' id='section-85'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-85'>#</a>
            </div>
            <p>Gradients are not needed</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">389</span>        <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-86'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-86'>#</a>
            </div>
            <p>Sample <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbq" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">ma</span><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> strokes</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">391</span>            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">longest_seq_len</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-87'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-87'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqn" style=""><span class="mopen" style="">[</span><span class="mord" style=""><span class="mopen coloredeq eqt" style="">(</span><span class="mord coloredeq eqt" style=""><span class="mord coloredeq eqba" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal coloredeq eqba" style="margin-right:0.03588em">y</span></span><span class="mpunct coloredeq eqt" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqt" style=""><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqci" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqbg" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" 
style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqbg" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqcf" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mclose coloredeq eqt" style="">)</span></span><span class="mpunct" style="">;</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcl" style="margin-right:0.04398em">z</span></span><span class="mclose" style="">]</span></span></span></span></span></span> is the input to the decoder</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">393</span>                <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">s</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">z</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)],</span> <span class="mi">2</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-88'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-88'>#</a>
            </div>
            <p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathcal" style="margin-right:0.14736em">N</span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="">ρ</span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqck" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span></span></span></span></span></span> and the next state from the decoder</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">396</span>                <span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-89'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-89'>#</a>
            </div>
            <p>Sample a stroke</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">398</span>                <span class="n">s</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sample_step</span><span class="p">(</span><span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">,</span> <span class="n">temperature</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-90'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-90'>#</a>
            </div>
            <p>Add the new stroke to the sequence of strokes</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">400</span>                <span class="n">seq</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">s</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-91'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-91'>#</a>
            </div>
            <p>Stop sampling if <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcf" style=""><span class="mord" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span></span></span></span></span>. This indicates that sketching has stopped</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">402</span>                <span class="k">if</span> <span class="n">s</span><span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="lineno">403</span>                    <span class="k">break</span></pre></div>
        </div>
    </div>
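    <div class='section'>
        <div class='docs'>
            <p>The loop above can be sketched in isolation as follows. This is a minimal illustration in plain Python, not the implementation on this page: <code>sample_sequence</code>, <code>sample_step</code> and <code>max_len</code> are hypothetical names, and the length cap is an assumption added so the loop always terminates.</p>
        </div>
        <div class='code'>

```python
def sample_sequence(sample_step, max_len):
    """Collect strokes until the end-of-sketch flag is set or max_len is reached.

    `sample_step` is a hypothetical callable returning one stroke of the form
    [dx, dy, p1, p2, p3]; `max_len` is an assumed cap on the sequence length.
    """
    seq = []
    for _ in range(max_len):
        s = sample_step()
        # Add the new stroke to the sequence
        seq.append(s)
        # Index 4 holds p3, the end-of-sketch flag
        if s[4] == 1:
            break
    return seq


# Usage with a fixed list of strokes standing in for the model
steps = iter([
    [0.0, 0.0, 1, 0, 0],
    [1.0, 1.0, 0, 1, 0],
    [2.0, 2.0, 0, 0, 1],  # p3 = 1, so sampling stops here
    [3.0, 3.0, 1, 0, 0],
])
sketch = sample_sequence(lambda: next(steps), max_len=10)
```

        </div>
    </div>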
    <div class='section' id='section-92'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-92'>#</a>
            </div>
            <p>Create a PyTorch tensor of the sequence of strokes</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">406</span>        <span class="n">seq</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-93'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-93'>#</a>
            </div>
            <p>Plot the sequence of strokes</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">409</span>        <span class="bp">self</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-94'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-94'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">411</span>    <span class="nd">@staticmethod</span>
<span class="lineno">412</span>    <span class="k">def</span> <span class="nf">_sample_step</span><span class="p">(</span><span class="n">dist</span><span class="p">:</span> <span class="s1">&#39;BivariateGaussianMixture&#39;</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-95'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-95'>#</a>
            </div>
            <p>Set the sampling temperature <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqbx" style=""><span class="mord mathnormal" style="margin-right:0.1132em">τ</span></span></span></span></span></span>. This is implemented in the class <code  class="highlight"><span></span><span class="n">BivariateGaussianMixture</span></code>
.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">414</span>        <span class="n">dist</span><span class="o">.</span><span class="n">set_temperature</span><span class="p">(</span><span class="n">temperature</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-96'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-96'>#</a>
            </div>
            <p>Get the temperature-adjusted <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathcal" style="margin-right:0.14736em">N</span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcc" style="">μ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="">ρ</span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">y</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">416</span>        <span class="n">pi</span><span class="p">,</span> <span class="n">mix</span> <span class="o">=</span> <span class="n">dist</span><span class="o">.</span><span class="n">get_distribution</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-97'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-97'>#</a>
            </div>
            <p>Sample from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord" style="">Π</span></span></span></span></span></span> the index of the mixture component to use</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">418</span>        <span class="n">idx</span> <span class="o">=</span> <span class="n">pi</span><span class="o">.</span><span class="n">sample</span><span class="p">()[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-98'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-98'>#</a>
            </div>
            <p>Create the categorical distribution <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqck" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span></span></span></span></span></span> from the log-probabilities <code  class="highlight"><span></span><span class="n">q_logits</span></code>
 (that is, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord accent" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" style="margin-right:0.03588em">q</span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.19444em;"><span></span></span></span></span></span></span></span></span></span></span>), divided by the temperature</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">421</span>        <span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">q_logits</span> <span class="o">/</span> <span class="n">temperature</span><span class="p">)</span></pre></div>
        </div>
    </div>
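    <div class='section'>
        <div class='docs'>
            <p>To see what dividing the logits by the temperature does, the following plain-Python sketch computes the resulting probabilities by hand. The function name <code>softmax_with_temperature</code> and the example logits are assumptions made for illustration; the code on this page relies on <code>torch.distributions.Categorical</code> instead.</p>
        </div>
        <div class='code'>

```python
import math


def softmax_with_temperature(logits, temperature):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [value / temperature for value in logits]
    # Subtract the maximum before exponentiating for numerical stability
    highest = max(scaled)
    exps = [math.exp(value - highest) for value in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical pen-state logits, chosen only for illustration
q_logits = [2.0, 1.0, 0.0]
cold = softmax_with_temperature(q_logits, 0.4)     # temperature below 1
neutral = softmax_with_temperature(q_logits, 1.0)  # unchanged distribution
hot = softmax_with_temperature(q_logits, 5.0)      # temperature above 1
```

        </div>
        <div class='docs'>
            <p>A temperature below 1 concentrates probability on the most likely pen state, while a temperature above 1 flattens the distribution towards uniform.</p>
        </div>
    </div>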
    <div class='section' id='section-99'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-99'>#</a>
            </div>
            <p>Sample from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqck" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">423</span>        <span class="n">q_idx</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">sample</span><span class="p">()[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-100'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-100'>#</a>
            </div>
            <p>Sample from the normal distributions in the mixture and pick the sample from the component indexed by <code  class="highlight"><span></span><span class="n">idx</span></code>
</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">426</span>        <span class="n">xy</span> <span class="o">=</span> <span class="n">mix</span><span class="o">.</span><span class="n">sample</span><span class="p">()[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">idx</span><span class="p">]</span></pre></div>
        </div>
    </div>
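    <div class='section'>
        <div class='docs'>
            <p>As a rough illustration of sampling from a mixture of bivariate normals, the sketch below uses only the Python standard library. It differs from the code on this page in one respect: it picks the component first and then draws a single sample, rather than sampling every component and indexing the result. All function names and the parameter layout <code>(mu_x, mu_y, sigma_x, sigma_y, rho_xy)</code> are assumptions made for this example.</p>
        </div>
        <div class='code'>

```python
import math
import random


def sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, rng):
    """Draw (x, y) from a bivariate normal with correlation rho_xy."""
    z1 = rng.gauss(0.0, 1.0)
    z2 = rng.gauss(0.0, 1.0)
    x = mu_x + sigma_x * z1
    # Mix z1 into y so that x and y have correlation rho_xy
    y = mu_y + sigma_y * (rho_xy * z1 + math.sqrt(1.0 - rho_xy ** 2) * z2)
    return x, y


def sample_from_mixture(weights, components, rng):
    """Pick a component index from `weights`, then sample from that component."""
    idx = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    return idx, sample_bivariate_normal(*components[idx], rng=rng)


rng = random.Random(0)
# Two hypothetical components: (mu_x, mu_y, sigma_x, sigma_y, rho_xy)
components = [(0.0, 0.0, 1.0, 1.0, 0.5), (3.0, -1.0, 0.5, 2.0, -0.3)]
idx, (dx, dy) = sample_from_mixture([0.7, 0.3], components, rng)
```

        </div>
    </div>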
    <div class='section' id='section-101'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-101'>#</a>
            </div>
            <p>Create an empty stroke <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord coloredeq eqba" style=""><span class="mord" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal" style="margin-right:0.03588em">y</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbh" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqci" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">429</span>        <span class="n">stroke</span> <span class="o">=</span> <span class="n">q_logits</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-102'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-102'>#</a>
            </div>
            <p>Set <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8777699999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqba" style=""><span class="mord" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal" style="margin-right:0.03588em">y</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">431</span>        <span class="n">stroke</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">xy</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-103'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-103'>#</a>
            </div>
            <p>Set the one-hot pen state <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbh" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqci" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" 
style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">433</span>        <span class="n">stroke</span><span class="p">[</span><span class="n">q_idx</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
        </div>
    </div>
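    <div class='section'>
        <div class='docs'>
            <p>The stroke construction above can be sketched with plain Python lists. <code>make_stroke</code> is a hypothetical helper written for this illustration; the code on this page writes into a PyTorch tensor instead.</p>
        </div>
        <div class='code'>

```python
def make_stroke(xy, q_idx):
    """Build a stroke [dx, dy, q1, q2, q3] from an offset and a pen-state index."""
    # Start from an empty stroke
    stroke = [0.0] * 5
    # The first two entries hold the offset (dx, dy)
    stroke[0], stroke[1] = xy
    # The pen state is one-hot, stored after the two offset entries
    stroke[q_idx + 2] = 1.0
    return stroke


# Pen-state index 1 corresponds to q2
example = make_stroke((0.5, -1.0), q_idx=1)
```

        </div>
    </div>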
    <div class='section' id='section-104'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-104'>#</a>
            </div>
            <p></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">435</span>        <span class="k">return</span> <span class="n">stroke</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-105'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-105'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">437</span>    <span class="nd">@staticmethod</span>
<span class="lineno">438</span>    <span class="k">def</span> <span class="nf">plot</span><span class="p">(</span><span class="n">seq</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-106'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-106'>#</a>
            </div>
            <p>Take the cumulative sums of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqz" style=""><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord coloredeq eqba" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal coloredeq eqba" style="margin-right:0.03588em">y</span></span><span class="mclose" style="">)</span></span></span></span></span></span> to get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mclose">)</span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">440</span>        <span class="n">seq</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">seq</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-107'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-107'>#</a>
            </div>
            <p>Create a new numpy array of the form <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">442</span>        <span class="n">seq</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span>
<span class="lineno">443</span>        <span class="n">seq</span> <span class="o">=</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">3</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-108'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-108'>#</a>
            </div>
            <p>Split the array at the points where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" style="margin-right:0.03588em">q</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqci" style=""><span class="mord" style="">1</span></span></span></span></span></span>, i.e. split the array of strokes at the points where the pen is lifted from the paper. This gives a list of stroke sequences.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">448</span>        <span class="n">strokes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">seq</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-109'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-109'>#</a>
            </div>
            <p>Plot each sequence of strokes</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">450</span>        <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">strokes</span><span class="p">:</span>
<span class="lineno">451</span>            <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">s</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="n">s</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-110'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-110'>#</a>
            </div>
            <p>Don't show the axes</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">453</span>        <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-111'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-111'>#</a>
            </div>
            <p>Show the plot</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">455</span>        <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span></pre></div>
        </div>
    </div>
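    <div class='section'>
        <div class='docs'>
            <p>The coordinate accumulation and stroke splitting used in the plotting steps above can be tried on a toy array. This is a minimal sketch, not part of the implementation: the four rows of <code>deltas</code> are made-up values in the form (dx, dy, pen-up flag), and <code>np.cumsum</code> stands in for the <code>torch.cumsum</code> call used on tensors.</p>
        </div>
        <div class='code'>
```python
import numpy as np

# Hypothetical toy sketch: each row is (dx, dy, pen_up); the values are made up.
deltas = np.array([
    [1.0, 0.0, 0.0],
    [1.0, 1.0, 1.0],   # pen is lifted after this point
    [0.0, 2.0, 0.0],
    [1.0, 0.0, 1.0],   # pen is lifted after this point
])

seq = deltas.copy()
# Turn relative offsets into absolute (x, y) coordinates
seq[:, 0:2] = np.cumsum(seq[:, 0:2], axis=0)

# Split right after every row whose pen-up flag is set
split_points = np.where(seq[:, 2] > 0)[0] + 1
strokes = np.split(seq, split_points)

# np.split also returns the remainder after the last split point,
# which is empty here because the final row has the pen-up flag set
stroke_lengths = [len(s) for s in strokes]
```
        </div>
    </div>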
    <div class='section' id='section-112'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-112'>#</a>
            </div>
            <h2>Configurations</h2>
<p>These are the default configurations, which can be adjusted later by passing a <code  class="highlight"><span></span><span class="nb">dict</span></code>.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">458</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">TrainValidConfigs</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-113'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-113'>#</a>
            </div>
            <p>Device configurations to pick the device on which to run the experiment</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">466</span>    <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-114'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-114'>#</a>
            </div>
            <p></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">468</span>    <span class="n">encoder</span><span class="p">:</span> <span class="n">EncoderRNN</span>
<span class="lineno">469</span>    <span class="n">decoder</span><span class="p">:</span> <span class="n">DecoderRNN</span>
<span class="lineno">470</span>    <span class="n">optimizer</span><span class="p">:</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adam</span>
<span class="lineno">471</span>    <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span>
<span class="lineno">472</span>
<span class="lineno">473</span>    <span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span>
<span class="lineno">474</span>    <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span>
<span class="lineno">475</span>    <span class="n">valid_loader</span><span class="p">:</span> <span class="n">DataLoader</span>
<span class="lineno">476</span>    <span class="n">train_dataset</span><span class="p">:</span> <span class="n">StrokesDataset</span>
<span class="lineno">477</span>    <span class="n">valid_dataset</span><span class="p">:</span> <span class="n">StrokesDataset</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-115'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-115'>#</a>
            </div>
            <p>Encoder and decoder hidden sizes</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">480</span>    <span class="n">enc_hidden_size</span> <span class="o">=</span> <span class="mi">256</span>
<span class="lineno">481</span>    <span class="n">dec_hidden_size</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-116'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-116'>#</a>
            </div>
            <p>Batch size</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">484</span>    <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">100</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-117'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-117'>#</a>
            </div>
            <p>Number of features in <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">487</span>    <span class="n">d_z</span> <span class="o">=</span> <span class="mi">128</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-118'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-118'>#</a>
            </div>
            <p>Number of distributions in the mixture, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcj" style=""><span class="mord mathnormal" style="margin-right:0.10903em">M</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">489</span>    <span class="n">n_distributions</span> <span class="o">=</span> <span class="mi">20</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-119'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-119'>#</a>
            </div>
            <p>Weight of the KL divergence loss, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbw" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.02691em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span><span class="mord mathnormal mtight" style="">L</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">492</span>    <span class="n">kl_div_loss_weight</span> <span class="o">=</span> <span class="mf">0.5</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-120'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-120'>#</a>
            </div>
            <p>Gradient clipping</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">494</span>    <span class="n">grad_clip</span> <span class="o">=</span> <span class="mf">1.</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-121'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-121'>#</a>
            </div>
            <p>Temperature <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqbx" style=""><span class="mord mathnormal" style="margin-right:0.1132em">τ</span></span></span></span></span></span> for sampling</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">496</span>    <span class="n">temperature</span> <span class="o">=</span> <span class="mf">0.4</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-122'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-122'>#</a>
            </div>
            <p>Filter out stroke sequences longer than <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">2</span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span><span class="mord coloredeq eqch" style=""><span class="mord" style="">0</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">499</span>    <span class="n">max_seq_length</span> <span class="o">=</span> <span class="mi">200</span>
<span class="lineno">500</span>
<span class="lineno">501</span>    <span class="n">epochs</span> <span class="o">=</span> <span class="mi">100</span>
<span class="lineno">502</span>
<span class="lineno">503</span>    <span class="n">kl_div_loss</span> <span class="o">=</span> <span class="n">KLDivLoss</span><span class="p">()</span>
<span class="lineno">504</span>    <span class="n">reconstruction_loss</span> <span class="o">=</span> <span class="n">ReconstructionLoss</span><span class="p">()</span></pre></div>
        </div>
    </div>
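    <div class='section'>
        <div class='docs'>
            <p>The idea of adjusting defaults by passing a dict can be illustrated in plain Python. This is only a sketch under stated assumptions: <code>ToyConfigs</code> and <code>apply_overrides</code> are hypothetical names made up for this example and are not the labml API, which handles overrides with its own machinery.</p>
        </div>
        <div class='code'>
```python
# Hypothetical stand-in for the configuration class; not the labml API.
class ToyConfigs:
    batch_size = 100
    d_z = 128
    kl_div_loss_weight = 0.5
    temperature = 0.4


def apply_overrides(conf, overrides):
    """Set each overridden attribute on the instance, rejecting unknown names."""
    for name, value in overrides.items():
        if not hasattr(conf, name):
            raise KeyError(f"unknown configuration: {name}")
        setattr(conf, name, value)
    return conf


# Override two defaults; everything else keeps its class-level value
conf = apply_overrides(ToyConfigs(), {"temperature": 0.8, "batch_size": 64})
```
        </div>
    </div>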
    <div class='section' id='section-123'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-123'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">506</span>    <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-124'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-124'>#</a>
            </div>
            <p>Initialize the encoder and decoder</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">508</span>        <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">EncoderRNN</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_z</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">enc_hidden_size</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">509</span>        <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">DecoderRNN</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_z</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dec_hidden_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_distributions</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-125'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-125'>#</a>
            </div>
            <p>Set the optimizer. Things like the optimizer type and learning rate are configurable.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">512</span>        <span class="n">optimizer</span> <span class="o">=</span> <span class="n">OptimizerConfigs</span><span class="p">()</span>
<span class="lineno">513</span>        <span class="n">optimizer</span><span class="o">.</span><span class="n">parameters</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
<span class="lineno">514</span>        <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-126'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-126'>#</a>
            </div>
            <p>Create the sampler</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">517</span>        <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">Sampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-127'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-127'>#</a>
            </div>
            <p><code  class="highlight"><span></span><span class="n">npz</span></code>
 file path is <code  class="highlight"><span></span><span class="n">data</span><span class="o">/</span><span class="n">sketch</span><span class="o">/</span><span class="p">[</span><span class="n">DATASET</span> <span class="n">NAME</span><span class="p">]</span><span class="o">.</span><span class="n">npz</span></code>
</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">520</span>        <span class="n">path</span> <span class="o">=</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;sketch&#39;</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span><span class="si">}</span><span class="s1">.npz&#39;</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-128'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-128'>#</a>
            </div>
            <p>Load the numpy file</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">522</span>        <span class="n">dataset</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">path</span><span class="p">),</span> <span class="n">encoding</span><span class="o">=</span><span class="s1">&#39;latin1&#39;</span><span class="p">,</span> <span class="n">allow_pickle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-129'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-129'>#</a>
            </div>
            <p>Create the training dataset</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">525</span>        <span class="bp">self</span><span class="o">.</span><span class="n">train_dataset</span> <span class="o">=</span> <span class="n">StrokesDataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="s1">&#39;train&#39;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-130'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-130'>#</a>
            </div>
            <p>Create the validation dataset, reusing the scale of the training dataset</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">527</span>        <span class="bp">self</span><span class="o">.</span><span class="n">valid_dataset</span> <span class="o">=</span> <span class="n">StrokesDataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="s1">&#39;valid&#39;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_dataset</span><span class="o">.</span><span class="n">scale</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-131'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-131'>#</a>
            </div>
            <p>Create the training data loader</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">530</span>        <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_dataset</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-132'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-132'>#</a>
            </div>
            <p>Create the validation data loader</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">532</span>        <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_dataset</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-133'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-133'>#</a>
            </div>
            <p>Add hooks to monitor layer outputs on Tensorboard</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">535</span>        <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="s1">&#39;encoder&#39;</span><span class="p">)</span>
<span class="lineno">536</span>        <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="s1">&#39;decoder&#39;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-134'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-134'>#</a>
            </div>
            <p>Configure the tracker to print the total train/validation loss</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">539</span>        <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">&quot;loss.total.*&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">540</span>
<span class="lineno">541</span>        <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-135'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-135'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">543</span>    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span>
<span class="lineno">544</span>        <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span>
<span class="lineno">545</span>        <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-136'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-136'>#</a>
            </div>
            <p>Move <code  class="highlight"><span></span><span class="n">data</span></code> and <code  class="highlight"><span></span><span class="n">mask</span></code> to the device and swap the sequence and batch dimensions.
<code  class="highlight"><span></span><span class="n">data</span></code> will have shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="mi">5</span><span class="p">]</span></code>
and <code  class="highlight"><span></span><span class="n">mask</span></code> will have shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">]</span></code>.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">550</span>        <span class="n">data</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">551</span>        <span class="n">mask</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-137'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-137'>#</a>
            </div>
            <p>Increment the step in training mode</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">554</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
<span class="lineno">555</span>            <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-138'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-138'>#</a>
            </div>
            <p>Encode the sequence of strokes</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">558</span>        <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">&quot;encoder&quot;</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-139'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-139'>#</a>
            </div>
            <p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqcl" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="">μ</span></span></span></span></span></span>, and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqbi" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.03588em">σ</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord" style="">^</span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">560</span>            <span class="n">z</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">sigma_hat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
        </div>
    </div>
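    <div class='section'>
        <div class='docs'>
            <p>As a rough illustration of how the three returned values relate, the sketch below draws z from μ and σ̂ using the reparameterization described in the Sketch RNN paper: σ = exp(σ̂ / 2) and z = μ + σ · ε, with ε drawn from a standard normal. It uses plain Python lists instead of tensors, and the helper name <code>sample_z</code> is hypothetical, not part of this implementation.</p>
        </div>
        <div class='code'>

```python
import math
import random


def sample_z(mu, sigma_hat, rng):
    """Hypothetical helper: z = mu + exp(sigma_hat / 2) * eps, element-wise."""
    z = []
    for m, s_hat in zip(mu, sigma_hat):
        sigma = math.exp(s_hat / 2)  # sigma_hat plays the role of a log-variance
        eps = rng.gauss(0.0, 1.0)    # eps ~ N(0, 1)
        z.append(m + sigma * eps)
    return z


rng = random.Random(0)  # fixed seed so the sketch is repeatable
mu = [0.0, 1.0, -1.0]
sigma_hat = [0.0, -2.0, -40.0]  # a very negative sigma_hat pins z close to mu
z = sample_z(mu, sigma_hat, rng)
```

        </div>
    </div>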
    <div class='section' id='section-140'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-140'>#</a>
            </div>
            <p>Decode the mixture of distributions and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord accent" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" style="margin-right:0.03588em">q</span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.19444em;"><span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">563</span>        <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">&quot;decoder&quot;</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-141'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-141'>#</a>
            </div>
            <p>Concatenate <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqn" style=""><span class="mopen" style="">[</span><span class="mord" style=""><span class="mopen coloredeq eqt" style="">(</span><span class="mord coloredeq eqt" style=""><span class="mord coloredeq eqba" style=""><span class="mord coloredeq eqbc" style=""><span class="mord coloredeq eqbm" style="">Δ</span><span class="mord mathnormal coloredeq eqbm" style="">x</span></span><span class="mpunct coloredeq eqbc" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbc" style="">Δ</span></span><span class="mord mathnormal coloredeq eqba" style="margin-right:0.03588em">y</span></span><span class="mpunct coloredeq eqt" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqt" style=""><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqci" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqbg" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" 
style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqbg" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqbg" style=""><span class="mord coloredeq eqcf" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mclose coloredeq eqt" style="">)</span></span><span class="mpunct" style="">;</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcl" style="margin-right:0.04398em">z</span></span><span class="mclose" style="">]</span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">565</span>            <span class="n">z_stack</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">566</span>            <span class="n">inputs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">data</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">z_stack</span><span class="p">],</span> <span class="mi">2</span><span class="p">)</span></pre></div>
        </div>
    </div>
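    <div class='section'>
        <div class='docs'>
            <p>The shapes involved can be sketched with nested lists standing in for tensors of shape <code>[seq_len, batch_size, features]</code>. The helper <code>decoder_inputs</code> and the toy values are hypothetical; the helper mirrors the two lines above by dropping the last step of <code>data</code> and appending the same z to every remaining step. The matching targets are <code>data[1:]</code>, as used for the reconstruction loss.</p>
        </div>
        <div class='code'>

```python
def decoder_inputs(data, z):
    """Hypothetical helper: pair every step except the last with its latent vector.

    data: list over time steps, each a list over the batch of 5-value strokes.
    z: list over the batch of latent vectors.
    """
    inputs = []
    for step in data[:-1]:  # the last step has no next stroke to predict
        inputs.append([stroke + z_b for stroke, z_b in zip(step, z)])
    return inputs


# 3 time steps, batch of 2, strokes are (dx, dy, p1, p2, p3)
data = [
    [[0, 0, 1, 0, 0], [0, 0, 1, 0, 0]],
    [[1, 2, 1, 0, 0], [3, 4, 0, 1, 0]],
    [[0, 0, 0, 0, 1], [0, 0, 0, 0, 1]],
]
z = [[0.5, -0.5], [0.1, 0.2]]  # latent size 2, chosen only for illustration

inputs = decoder_inputs(data, z)
targets = data[1:]  # the decoder is trained to predict the next stroke
```

        </div>
    </div>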
    <div class='section' id='section-142'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-142'>#</a>
            </div>
            <p>Get the mixture of distributions and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord accent" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqck" style="margin-right:0.03588em">q</span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord" style="">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.19444em;"><span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">568</span>            <span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-143'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-143'>#</a>
            </div>
            <p>Compute the loss</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">571</span>        <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;loss&#39;</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-144'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-144'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbv" style=""><span class="mord" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span><span class="mord mathnormal mtight" style="">L</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">573</span>            <span class="n">kl_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_loss</span><span class="p">(</span><span class="n">sigma_hat</span><span class="p">,</span> <span class="n">mu</span><span class="p">)</span></pre></div>
        </div>
    </div>
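    <div class='section'>
        <div class='docs'>
            <p>For intuition, here is a minimal sketch of the KL term for a single latent vector, assuming the formula from the Sketch RNN paper: L_KL = -1 / (2 N_z) · Σ (1 + σ̂ - μ² - exp(σ̂)). It is written in plain Python and is not the <code>kl_div_loss</code> module used above, which operates on tensors and may reduce over the batch differently.</p>
        </div>
        <div class='code'>

```python
import math


def kl_div_loss(sigma_hat, mu):
    """Illustrative only: L_KL = -1 / (2 N_z) * sum(1 + sigma_hat - mu^2 - exp(sigma_hat))."""
    n_z = len(mu)
    total = sum(1 + s - m * m - math.exp(s) for s, m in zip(sigma_hat, mu))
    return -total / (2 * n_z)


# A posterior equal to the standard normal prior has zero divergence
zero_kl = kl_div_loss([0.0, 0.0], [0.0, 0.0])
# Moving the mean away from zero makes the divergence positive
shifted_kl = kl_div_loss([0.0, 0.0], [1.0, -1.0])
```

        </div>
    </div>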
    <div class='section' id='section-145'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-145'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbz" style=""><span class="mord" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.00773em">R</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">575</span>            <span class="n">reconstruction_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reconstruction_loss</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span> <span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-146'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-146'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord mathnormal">L</span><span class="mord mathnormal">oss</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbz" style=""><span class="mord" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.00773em">R</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbw" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.02691em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord 
mathnormal mtight" style="margin-right:0.07153em">K</span><span class="mord mathnormal mtight" style="">L</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbv" style=""><span class="mord" style=""><span class="mord mathnormal" style="">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span><span class="mord mathnormal mtight" style="">L</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">577</span>            <span class="n">loss</span> <span class="o">=</span> <span class="n">reconstruction_loss</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_loss_weight</span> <span class="o">*</span> <span class="n">kl_loss</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-147'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-147'>#</a>
            </div>
            <p>Track the losses</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">580</span>            <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.kl.&quot;</span><span class="p">,</span> <span class="n">kl_loss</span><span class="p">)</span>
<span class="lineno">581</span>            <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.reconstruction.&quot;</span><span class="p">,</span> <span class="n">reconstruction_loss</span><span class="p">)</span>
<span class="lineno">582</span>            <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.total.&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-148'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-148'>#</a>
            </div>
            <p>Only if we are in training mode</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">585</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-149'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-149'>#</a>
            </div>
            <p>Run the optimizer</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">587</span>            <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;optimize&#39;</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-150'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-150'>#</a>
            </div>
            <p>Set <code  class="highlight"><span></span><span class="n">grad</span></code>
 to zero</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">589</span>                <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-151'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-151'>#</a>
            </div>
            <p>Compute gradients</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">591</span>                <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-152'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-152'>#</a>
            </div>
            <p>Log model parameters and gradients on the last batch</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">593</span>                <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
<span class="lineno">594</span>                    <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">encoder</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-153'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-153'>#</a>
            </div>
            <p>Clip gradients</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">596</span>                <span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_clip</span><span class="p">)</span>
<span class="lineno">597</span>                <span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_clip</span><span class="p">)</span></pre></div>
        </div>
    </div>
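    <div class='section'>
        <div class='docs'>
            <p>Norm-based clipping rescales a whole group of gradients together whenever their combined L2 norm exceeds a limit, which keeps their direction and only shrinks their length. The sketch below shows the idea on a flat list of numbers. The helper <code>clip_by_total_norm</code> is hypothetical and the <code>1e-6</code> term is an assumption added for numerical safety; it is not a substitute for <code>clip_grad_norm_</code>. In the code above the encoder and decoder are clipped separately, each against <code>grad_clip</code>.</p>
        </div>
        <div class='code'>

```python
import math


def clip_by_total_norm(grads, max_norm):
    """Hypothetical helper: rescale grads so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # The scale never exceeds 1, so small gradients are left as they are
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm


# Norm 5 is above the limit of 1, so both values shrink by the same factor
clipped, norm_before = clip_by_total_norm([3.0, 4.0], max_norm=1.0)
# Norm 0.5 is below the limit, so nothing changes
unchanged, _ = clip_by_total_norm([0.3, 0.4], max_norm=1.0)
```

        </div>
    </div>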
    <div class='section' id='section-154'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-154'>#</a>
            </div>
            <p>Take an optimizer step</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">599</span>                <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="lineno">600</span>
<span class="lineno">601</span>        <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-155'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-155'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">603</span>    <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-156'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-156'>#</a>
            </div>
            <p>Randomly pick a sample from the validation dataset to pass to the encoder</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">605</span>        <span class="n">data</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_dataset</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_dataset</span><span class="p">))]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-157'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-157'>#</a>
            </div>
            <p>Add the batch dimension and move the sample to the device</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">607</span>        <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
        </div>
    </div>
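    <div class='section'>
        <div class='docs'>
            <p>The two steps above, picking one validation sequence and giving it a batch dimension of size one, can be sketched without tensors. The helper <code>pick_and_batch</code> and the toy dataset are hypothetical; nested lists stand in for a tensor going from shape <code>[seq_len, 5]</code> to <code>[seq_len, 1, 5]</code>, which is what <code>unsqueeze(1)</code> does.</p>
        </div>
        <div class='code'>

```python
import random


def pick_and_batch(dataset, rng):
    """Hypothetical helper: pick one sequence at random and wrap each step as a batch of 1."""
    seq = dataset[rng.randrange(len(dataset))]
    return [[step] for step in seq]  # [seq_len, 5] becomes [seq_len, 1, 5]


# Two toy sequences of (dx, dy, p1, p2, p3) strokes, made up for illustration
dataset = [
    [[0, 0, 1, 0, 0], [1, 1, 0, 0, 1]],
    [[0, 0, 1, 0, 0], [2, 0, 1, 0, 0], [0, 2, 0, 0, 1]],
]
batch = pick_and_batch(dataset, random.Random(0))
```

        </div>
    </div>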
    <div class='section' id='section-158'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-158'>#</a>
            </div>
            <p>Sample</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">609</span>        <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-159'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-159'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">612</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
<span class="lineno">613</span>    <span class="n">configs</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span>
<span class="lineno">614</span>    <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;sketch_rnn&quot;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-160'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-160'>#</a>
            </div>
            <p>Pass a dictionary of configurations</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">617</span>    <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">configs</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">618</span>        <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-161'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-161'>#</a>
            </div>
            <p>We use a learning rate of <code  class="highlight"><span></span><span class="mf">1e-3</span></code>
 because it lets us see results sooner. The paper suggested <code  class="highlight"><span></span><span class="mf">1e-4</span></code>
.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">621</span>        <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">1e-3</span><span class="p">,</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-162'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-162'>#</a>
            </div>
            <p>Name of the dataset</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">623</span>        <span class="s1">&#39;dataset_name&#39;</span><span class="p">:</span> <span class="s1">&#39;bicycle&#39;</span><span class="p">,</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-163'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-163'>#</a>
            </div>
            <p>Number of inner iterations within an epoch, used to switch between training, validation and sampling.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">625</span>        <span class="s1">&#39;inner_iterations&#39;</span><span class="p">:</span> <span class="mi">10</span>
<span class="lineno">626</span>    <span class="p">})</span>
<span class="lineno">627</span>
<span class="lineno">628</span>    <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-164'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-164'>#</a>
            </div>
            <p>Run the experiment</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">630</span>        <span class="n">configs</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
<span class="lineno">631</span>
<span class="lineno">632</span>
<span class="lineno">633</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
<span class="lineno">634</span>    <span class="n">main</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='footer'>
        <a href="https://papers.labml.ai">Trending Research Papers</a>
        <a href="https://labml.ai">labml.ai</a>
    </div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
    function handleImages() {
        var images = document.querySelectorAll('p>img')

        for (var i = 0; i < images.length; ++i) {
            handleImage(images[i])
        }
    }

    function handleImage(img) {
        img.parentElement.style.textAlign = 'center'

        var modal = document.createElement('div')
        modal.id = 'modal'

        var modalContent = document.createElement('div')
        modal.appendChild(modalContent)

        var modalImage = document.createElement('img')
        modalContent.appendChild(modalImage)

        var span = document.createElement('span')
        span.classList.add('close')
        span.textContent = 'x'
        modal.appendChild(span)

        img.onclick = function () {
            console.log('clicked')
            document.body.appendChild(modal)
            modalImage.src = img.src
        }

        span.onclick = function () {
            document.body.removeChild(modal)
        }
    }

    handleImages()
</script>
</body>
</html>